| """ | |
| Hugging Face Space for STARFlow | |
| Text-to-Image and Text-to-Video Generation | |
| This app allows you to run STARFlow inference on Hugging Face GPU infrastructure. | |
| """ | |
# Fix OpenMP warning - MUST be set BEFORE importing torch
import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
# Reduce CUDA memory fragmentation (PYTORCH_CUDA_ALLOC_CONF is the documented name for this option)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import warnings
import gradio as gr
import torch
import subprocess
import traceback
from pathlib import Path

# Suppress harmless warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*")
# Try to import huggingface_hub for downloading checkpoints
try:
    from huggingface_hub import hf_hub_download
    HF_HUB_AVAILABLE = True
except ImportError:
    HF_HUB_AVAILABLE = False
    print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")
# Check if running on Hugging Face Spaces
HF_SPACE = os.environ.get("SPACE_ID") is not None

# Import spaces module for ZeroGPU support (required for GPU allocation)
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
    print("⚠️ spaces module not available. GPU decorator will be skipped.")
# Default checkpoint paths (if uploaded to Space Files)
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"

# Model Hub repositories (if using the Hugging Face Model Hub)
# Set these to your Model Hub repo IDs if you upload checkpoints there
# Format: "username/repo-name"
IMAGE_CHECKPOINT_REPO = "GlobalStudio/starflow-3b-checkpoint"  # Update this after creating the Model Hub repo
VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint"  # Update this after creating the Model Hub repo
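
# A minimal, hypothetical sketch of how a checkpoint could be pushed to such a
# Model Hub repo with huggingface_hub (the repo name and local path below are
# illustrative assumptions, not part of this Space's runtime code; it requires a
# write token, e.g. via `huggingface-cli login`):
#
#   from huggingface_hub import HfApi
#   api = HfApi()
#   api.create_repo("GlobalStudio/starflow-3b-checkpoint", repo_type="model", exist_ok=True)
#   api.upload_file(
#       path_or_fileobj="ckpts/starflow_3B_t2i_256x256.pth",
#       path_in_repo="starflow_3B_t2i_256x256.pth",
#       repo_id="GlobalStudio/starflow-3b-checkpoint",
#   )
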

def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
    """Get checkpoint path, downloading from the Hub if needed."""
    # If the user uploaded a file, use it
    if checkpoint_file is not None and checkpoint_file != "":
        if hasattr(checkpoint_file, 'name'):
            return checkpoint_file.name
        checkpoint_str = str(checkpoint_file)
        # If it's a file path that exists, use it
        if os.path.exists(checkpoint_str):
            return checkpoint_str
        # If it's the default path but doesn't exist, continue to download
        if checkpoint_str == default_local_path and not os.path.exists(checkpoint_str):
            pass  # Continue to the download logic below
        elif checkpoint_str != default_local_path:
            return checkpoint_str
    # Try the local path first
    if os.path.exists(default_local_path):
        return default_local_path
    # Try downloading from the Model Hub if configured
    if repo_id and filename and HF_HUB_AVAILABLE:
        try:
            # Use /workspace if available (HF Spaces persistent storage), otherwise use a local cache
            if os.path.exists("/workspace"):
                cache_dir = "/workspace/checkpoints"
            elif os.path.exists("/tmp"):
                cache_dir = "/tmp/checkpoints"
            else:
                # Local development: use the user cache directory
                cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "starflow")
            os.makedirs(cache_dir, exist_ok=True)
            # Check if already downloaded
            possible_path = os.path.join(cache_dir, "models--" + repo_id.replace("/", "--"), "snapshots", "*", filename)
            import glob
            existing = glob.glob(possible_path)
            if existing:
                checkpoint_path = existing[0]
                print(f"✅ Using cached checkpoint: {checkpoint_path}")
                return checkpoint_path
            # Download with progress tracking
            import time
            start_time = time.time()
            print(f"📥 Downloading checkpoint from {repo_id}...")
            print("File size: several GB - this may take 10-30 minutes")
            print(f"Cache directory: {cache_dir}")
            print("Progress will be shown below...")
            # hf_hub_download shows a progress bar when tqdm is installed and
            # resumes interrupted downloads in recent huggingface_hub versions
            try:
                checkpoint_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
            except Exception:
                # Retry once; large downloads occasionally fail on transient network errors
                print("First download attempt failed, retrying...")
                checkpoint_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
            elapsed = time.time() - start_time
            print(f"✅ Download completed in {elapsed/60:.1f} minutes")
            print(f"✅ Checkpoint at: {checkpoint_path}")
            return checkpoint_path
        except Exception as e:
            error_detail = str(e)
            if "404" in error_detail or "not found" in error_detail.lower():
                return None, f"Checkpoint not found on the Model Hub.\n\nPlease verify:\n1. Repo exists: https://huggingface.co/{repo_id}\n2. File exists: {filename}\n3. Repo is Public (not Private)\n\nError: {error_detail}"
            return None, f"Error downloading checkpoint: {error_detail}\n\nDownloading a multi-gigabyte checkpoint may take 10-30 minutes. Please wait or check your internet connection."
    # No checkpoint found
    return None, "Checkpoint not found. Please upload a checkpoint file or configure a Model Hub repository."

# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
if torch.cuda.is_available():
    print(f"✅ CUDA available! Device: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   PyTorch Version: {torch.__version__}")
    print(f"   GPU Count: {torch.cuda.device_count()}")
    print(f"   Current Device: {torch.cuda.current_device()}")
else:
    print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
    print(f"   PyTorch Version: {torch.__version__}")

# Apply @spaces.GPU decorator if available (required for ZeroGPU)
# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
    """Generate image from text prompt."""
    return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)

# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
    generate_image = spaces.GPU(generate_image)
    print("✅ ZeroGPU decorator applied to generate_image")

def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
    """Generate image from text prompt (implementation)."""
    # Get checkpoint path (from upload, local file, or Model Hub)
    status_msg = ""
    # Handle checkpoint file (might be a string from the hidden Textbox)
    if checkpoint_file == DEFAULT_IMAGE_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
        # Use Model Hub download
        checkpoint_file = None
    result = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_IMAGE_CHECKPOINT,
        IMAGE_CHECKPOINT_REPO,
        "starflow_3B_t2i_256x256.pth"
    )
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]  # Error message
    checkpoint_path = result
    # Show status
    status_msg += f"Using checkpoint: {checkpoint_path}\n"
    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. The Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. The file name matches: starflow_3B_t2i_256x256.pth\n3. The repo is Public (not Private)"
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure the config file exists."
    status_msg += "Starting image generation...\n"
    status_msg += "⏱️ Timing breakdown:\n"
    status_msg += "  - Checkpoint download: 10-30 min (first time only, ~15.5 GB)\n"
    status_msg += "  - Model loading: 2-5 min (first time only)\n"
    status_msg += "  - Image generation: 1-3 min\n"
    status_msg += "  - Subsequent runs: only generation time (~1-3 min)\n"
    try:
        # Create output directory (use /tmp for logs, outputs/ for images)
        # In HF Spaces, /tmp is accessible and outputs/ may not be visible in the Files tab
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)
        # Build the sampling command
        cmd = [
            "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--save_folder", "1",
            "--finetuned_vae", "none",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--jacobi_block_size", "16",
            "--logdir", str(output_dir)  # Set logdir to the outputs directory for images
        ]
| status_msg += "π Running generation...\n" | |
| status_msg += "π Current step: Model inference (checkpoint should already be downloaded)\n" | |
| # Note about log file location | |
| status_msg += f"\nπ LOGS:\n" | |
| status_msg += f" All logs will be shown in the status output below\n" | |
| status_msg += f" (Logs are captured in real-time)\n\n" | |
        # Ensure GPU environment variables are passed to the subprocess
        env = os.environ.copy()
        # Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
        if 'CUDA_VISIBLE_DEVICES' in env:
            print(f"✅ CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
        # Verify the GPU is available before starting
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            free_memory = total_memory - reserved
            status_msg += f"✅ GPU available: {gpu_name}\n"
            status_msg += f"   Total Memory: {total_memory:.1f} GB\n"
            status_msg += f"   Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB, Free: {free_memory:.2f} GB\n"
            # Warn if memory is low
            if free_memory < 2.0:
                status_msg += f"⚠️ Warning: Low GPU memory ({free_memory:.2f} GB free). Model may not fit.\n"
        else:
            status_msg += "⚠️ Warning: CUDA not available, will use CPU (very slow)\n"
        # Run with a timeout (45 minutes max - allows for download + generation)
        # Capture output for the log shown in the status box
        # Note: if the process is killed (e.g., GPU abort), we still capture what was output
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=os.getcwd(),
            env=env,  # Pass environment variables (including CUDA_VISIBLE_DEVICES)
            timeout=2700
        )
        # Build comprehensive log content for display (not relying on file access)
        log_content_parts = []
        log_content_parts.append("=== GENERATION LOG ===\n\n")
        log_content_parts.append(f"Command: {' '.join(cmd)}\n\n")
        log_content_parts.append("Environment Variables:\n")
        log_content_parts.append(f"  CUDA_VISIBLE_DEVICES={env.get('CUDA_VISIBLE_DEVICES', 'not set')}\n")
        log_content_parts.append(f"  CUDA_AVAILABLE={torch.cuda.is_available()}\n")
        if torch.cuda.is_available():
            log_content_parts.append(f"  GPU_NAME={torch.cuda.get_device_name(0)}\n")
            log_content_parts.append(f"  GPU_MEMORY_TOTAL={torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB\n")
        log_content_parts.append("\n")
        log_content_parts.append("=== STDOUT ===\n")
        log_content_parts.append(result.stdout if result.stdout else "(empty)\n")
        log_content_parts.append("\n\n=== STDERR ===\n")
        log_content_parts.append(result.stderr if result.stderr else "(empty)\n")
        log_content_parts.append(f"\n\n=== RETURN CODE: {result.returncode} ===\n")
        log_content = ''.join(log_content_parts)
        if result.returncode != 0:
            error_msg = f"❌ Error during generation (return code: {result.returncode})\n\n"
            error_msg += "⚠️ IMPORTANT: Scroll down in this status box to see full error details!\n"
            error_msg += "   The error message below contains STDERR, STDOUT, and debugging info.\n\n"
            # Check for GPU abort or CUDA errors
            error_output = (result.stderr + result.stdout).lower()
            if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output or "killed" in error_output:
                error_msg += "⚠️ GPU ERROR DETECTED\n\n"
                error_msg += "Possible causes:\n"
                error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
                error_msg += "2. CUDA out of memory (model too large for GPU)\n"
                error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n"
                error_msg += "4. Process killed due to memory limit\n\n"
                error_msg += "Solutions:\n"
                error_msg += "- Model is now using bfloat16/float16 for memory efficiency\n"
                error_msg += "- Try again (GPU may have been released)\n"
                error_msg += "- Check Space logs for detailed error\n"
                error_msg += "- Ensure @spaces.GPU decorator is applied\n"
                error_msg += "- Consider using paid GPU tier for longer runs\n"
                error_msg += "- If issue persists, model may be too large for available GPU\n\n"
            # Show detailed error information
            error_msg += f"\n{'='*80}\n"
            error_msg += "📋 DETAILED ERROR LOGS\n"
            error_msg += f"{'='*80}\n\n"
            # Show return code and command
            error_msg += f"Return Code: {result.returncode}\n"
            error_msg += f"Command: {' '.join(cmd)}\n\n"
            # Show STDERR (usually contains the actual error)
            if result.stderr:
                error_msg += "=== STDERR (Error Output) ===\n"
                error_msg += f"{result.stderr}\n\n"
            else:
                error_msg += "⚠️ No STDERR output (process may have been killed silently)\n\n"
            # Show STDOUT (may contain useful info)
            if result.stdout:
                error_msg += "=== STDOUT (Standard Output) ===\n"
                # Show the last 5000 chars of stdout
                stdout_preview = result.stdout[-5000:] if len(result.stdout) > 5000 else result.stdout
                error_msg += f"{stdout_preview}\n\n"
            # Show the full log content directly in the error message (no file access needed)
            error_msg += "=== FULL GENERATION LOG ===\n"
            error_msg += f"{log_content}\n\n"
            # Instructions on where to find more info
            error_msg += f"{'='*80}\n"
            error_msg += "🔍 ADDITIONAL DEBUGGING:\n"
            error_msg += f"{'='*80}\n"
            error_msg += "1. Check the Space 'Logs' tab for container logs\n"
            error_msg += "2. Look for messages from sample.py\n"
            error_msg += "3. Check for GPU abort or CUDA errors\n"
            error_msg += "4. All logs are shown above in this error message\n"
            error_msg += f"{'='*80}\n"
            return None, error_msg
| status_msg += "Generation complete. Looking for output...\n" | |
| status_msg += f"\nπ‘ Note: Generated images will appear directly in the UI above.\n" | |
| status_msg += f" The outputs/ folder is runtime-generated and not visible in Files tab.\n\n" | |
| # Find the generated image | |
| # The sample.py script saves to logdir/model_name/... | |
| # Model name is derived from checkpoint path stem | |
| checkpoint_stem = Path(checkpoint_path).stem | |
| model_output_dir = output_dir / checkpoint_stem | |
| status_msg += f"Searching in: {model_output_dir}\n" | |
| status_msg += f"Also searching recursively in: {output_dir}\n" | |
| # Search in model-specific directory first, then recursively | |
| search_paths = [model_output_dir, output_dir] | |
| output_files = [] | |
| for search_path in search_paths: | |
| if search_path.exists(): | |
| # Look for PNG, JPG, JPEG files | |
| found = list(search_path.glob("**/*.png")) + list(search_path.glob("**/*.jpg")) + list(search_path.glob("**/*.jpeg")) | |
| output_files.extend(found) | |
| status_msg += f"Found {len(found)} files in {search_path}\n" | |
| if output_files: | |
| # Get the most recent file | |
| latest_file = max(output_files, key=lambda p: p.stat().st_mtime) | |
| # Use absolute path for Gradio to access the file | |
| image_path = str(latest_file.absolute()) | |
| status_msg += f"β Found image: {image_path}\n" | |
| status_msg += f"\nπ‘ Note: Generated images appear directly in the UI above.\n" | |
| status_msg += f" The outputs/ folder is not visible in Files tab (runtime files).\n" | |
| return image_path, status_msg + "\nβ Success! Image generated." | |
        else:
            # Debug: list what's actually in the directory
            debug_info = "\n\nDebug info:\n"
            debug_info += f"Output dir exists: {output_dir.exists()}\n"
            if output_dir.exists():
                debug_info += f"Contents of {output_dir}:\n"
                for item in output_dir.iterdir():
                    debug_info += f"  - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
            if model_output_dir.exists():
                debug_info += f"\nContents of {model_output_dir}:\n"
                for item in model_output_dir.iterdir():
                    debug_info += f"  - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
            error_msg = status_msg + "Error: Generated image not found.\n"
            error_msg += f"Searched in: {output_dir} and {model_output_dir}\n"
            error_msg += debug_info
            if log_content:
                error_msg += f"\n\n📋 Full log details:\n{log_content[-2000:]}"
            else:
                error_msg += f"\n\nCheck stdout:\n{result.stdout[-1000:] if result.stdout else '(no output)'}"
            return None, error_msg
    except subprocess.TimeoutExpired:
        error_msg = "❌ Generation timed out after 45 minutes\n\n"
        error_msg += "This may indicate:\n"
        error_msg += "- Model loading is taking too long\n"
        error_msg += "- GPU timeout (ZeroGPU may have limits)\n"
        error_msg += "- Process hung or stuck\n\n"
        error_msg += "Try:\n"
        error_msg += "- Check the Space Logs tab for more details\n"
        error_msg += "- Try generating again\n"
        error_msg += "- Check if the GPU is still available\n"
        return None, error_msg
    except Exception as e:
        # Get the full traceback for debugging
        error_traceback = traceback.format_exc()
        error_msg = "❌ Exception occurred during image generation:\n\n"
        error_msg += f"Error Type: {type(e).__name__}\n"
        error_msg += f"Error Message: {str(e)}\n\n"
        error_msg += f"=== FULL TRACEBACK ===\n{error_traceback}\n\n"
        error_msg += "💡 TIP: Scroll down in this status box to see full error details.\n"
        error_msg += "   You can also copy the error message using the copy button.\n"
        return None, error_msg

# Apply @spaces.GPU decorator if available (required for ZeroGPU)
# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
    """Generate video from text prompt."""
    return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)

# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
    generate_video = spaces.GPU(generate_video)
    print("✅ ZeroGPU decorator applied to generate_video")

def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
    """Generate video from text prompt (implementation)."""
    # Handle checkpoint file (might be a string from the hidden Textbox)
    if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
        # Use Model Hub download
        checkpoint_file = None
    # Get checkpoint path (from upload, local file, or Model Hub)
    result = get_checkpoint_path(
        checkpoint_file,
        DEFAULT_VIDEO_CHECKPOINT,
        VIDEO_CHECKPOINT_REPO,
        "starflow-v_7B_t2v_caus_480p_v3.pth"
    )
    if isinstance(result, tuple) and result[0] is None:
        return None, result[1]  # Error message
    checkpoint_path = result
    if not os.path.exists(checkpoint_path):
        return None, f"Error: Checkpoint file not found at {checkpoint_path}."
    if not config_path or not os.path.exists(config_path):
        return None, "Error: Config file not found. Please ensure the config file exists."
    # Handle input image
    input_image_path = None
    if input_image is not None:
        if hasattr(input_image, 'name'):
            input_image_path = input_image.name
        else:
            input_image_path = str(input_image)
    try:
        # Create output directory
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)
        # Build the sampling command
        cmd = [
            "python", "sample.py",
            "--model_config_path", config_path,
            "--checkpoint_path", checkpoint_path,
            "--caption", prompt,
            "--sample_batch_size", "1",
            "--cfg", str(cfg),
            "--aspect_ratio", aspect_ratio,
            "--seed", str(seed),
            "--out_fps", "16",
            "--save_folder", "1",
            "--jacobi", "1",
            "--jacobi_th", "0.001",
            "--finetuned_vae", "none",
            "--disable_learnable_denoiser", "0",
            "--jacobi_block_size", "32",
            "--target_length", str(target_length)
        ]
        if input_image_path and os.path.exists(input_image_path):
            cmd.extend(["--input_image", input_image_path])
        else:
            cmd.extend(["--input_image", "none"])
        # Ensure GPU environment variables are passed to the subprocess
        env = os.environ.copy()
        # Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
        if 'CUDA_VISIBLE_DEVICES' in env:
            print(f"✅ CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
        # Verify the GPU is available before starting
        if torch.cuda.is_available():
            print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            cwd=os.getcwd(),
            env=env,  # Pass environment variables (including CUDA_VISIBLE_DEVICES)
            timeout=3600  # 60 minutes for video generation
        )
        if result.returncode != 0:
            error_msg = f"❌ Error during video generation (return code: {result.returncode})\n\n"
            # Check for GPU abort or CUDA errors
            error_output = (result.stderr + result.stdout).lower()
            if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output:
                error_msg += "⚠️ GPU ERROR DETECTED\n\n"
                error_msg += "Possible causes:\n"
                error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
                error_msg += "2. CUDA out of memory (video generation needs more GPU memory)\n"
                error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n\n"
                error_msg += "Solutions:\n"
                error_msg += "- Try again (GPU may have been released)\n"
                error_msg += "- Check Space logs for detailed error\n"
                error_msg += "- Ensure @spaces.GPU decorator is applied\n"
                error_msg += "- Consider using paid GPU tier for longer runs\n\n"
            error_msg += f"=== STDERR ===\n{result.stderr}\n\n"
            error_msg += f"=== STDOUT ===\n{result.stdout}\n"
            return None, error_msg
        # Find the generated video
        output_files = list(output_dir.glob("**/*.mp4")) + list(output_dir.glob("**/*.gif"))
        if output_files:
            latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
            return str(latest_file), "Success! Video generated."
        else:
            return None, "Error: Generated video not found."
    except Exception as e:
        # Get the full traceback for debugging
        error_traceback = traceback.format_exc()
        error_msg = "❌ Exception occurred during video generation:\n\n"
        error_msg += f"Error Type: {type(e).__name__}\n"
        error_msg += f"Error Message: {str(e)}\n\n"
        error_msg += f"=== FULL TRACEBACK ===\n{error_traceback}\n"
        return None, error_msg
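
# Local smoke test (illustrative only; the prompt is an assumption and calling the
# _impl function directly bypasses the ZeroGPU decorator). On a Space, generation
# is driven by the Gradio UI below instead:
#
#   image, status = _generate_image_impl(
#       "a film still of a cat playing piano", "1:1", 3.6, 999,
#       None, "configs/starflow_3B_t2i_256x256.yaml")
#   print(status)
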

# Create Gradio interface
with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
    # Add custom CSS using gr.HTML
    gr.HTML("""
    <style>
    .gradio-container {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
    }
    .main-header {
        text-align: center;
        padding: 2rem 0;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 2rem;
    }
    .main-header h1 {
        margin: 0;
        font-size: 2.5rem;
        font-weight: 700;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
    }
    .main-header p {
        margin: 0.5rem 0 0 0;
        font-size: 1.1rem;
        opacity: 0.95;
    }
    .info-box {
        background: #f0f4ff;
        border-left: 4px solid #667eea;
        padding: 1rem;
        border-radius: 5px;
        margin-bottom: 1.5rem;
    }
    .input-section {
        background: #fafafa;
        padding: 1.5rem;
        border-radius: 10px;
        border: 1px solid #e0e0e0;
    }
    .output-section {
        background: white;
        padding: 1.5rem;
        border-radius: 10px;
        border: 1px solid #e0e0e0;
        box-shadow: 0 2px 8px rgba(0,0,0,0.05);
    }
    .generate-btn {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
        color: white !important;
        font-weight: 600 !important;
        padding: 0.75rem 2rem !important;
        border-radius: 8px !important;
        border: none !important;
        font-size: 1rem !important;
        transition: transform 0.2s, box-shadow 0.2s !important;
    }
    .generate-btn:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
    }
    .status-box {
        font-family: 'Monaco', 'Menlo', monospace;
        font-size: 0.9rem;
        line-height: 1.6;
    }
    </style>
    """)

    # Header
    gr.HTML("""
    <div class="main-header">
        <h1>🎨 STARFlow</h1>
        <p>Scalable Transformer Auto-Regressive Flow</p>
        <p style="font-size: 0.95rem; margin-top: 0.5rem; opacity: 0.9;">
            Generate high-quality images and videos from text prompts
        </p>
    </div>
    """)

    # Info box
    gr.Markdown("""
    <div class="info-box">
    <strong>ℹ️ Note:</strong> Checkpoints are automatically downloaded from the Model Hub on first use.
    First generation may take 10-20 minutes for download and model loading.
    </div>
    """)
    with gr.Tabs() as tabs:
        with gr.Tab("🖼️ Text-to-Image", id="image_tab"):
            with gr.Row():
                with gr.Column(scale=1, min_width=400):
                    gr.Markdown("### ⚙️ Generation Settings")
                    with gr.Group():
                        image_prompt = gr.Textbox(
                            label="📝 Prompt",
                            placeholder="a film still of a cat playing piano",
                            lines=4
                        )
                        image_config = gr.Textbox(
                            label="⚙️ Config Path",
                            value="configs/starflow_3B_t2i_256x256.yaml",
                            interactive=False,
                            visible=False  # Hidden - not needed for users
                        )
                    with gr.Group():
                        gr.Markdown("#### 🎨 Image Settings")
                        image_aspect = gr.Dropdown(
                            label="Aspect Ratio",
                            choices=["1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4"],
                            value="1:1"
                        )
                        image_cfg = gr.Slider(
                            label="CFG Scale",
                            minimum=1.0,
                            maximum=10.0,
                            value=3.6,
                            step=0.1
                        )
                        image_seed = gr.Number(
                            value=999,
                            label="🎲 Seed",
                            precision=0
                        )
                    # Hidden checkpoint field
                    image_checkpoint = gr.Textbox(
                        label="Model Checkpoint Path (auto-downloaded)",
                        value=DEFAULT_IMAGE_CHECKPOINT,
                        visible=False
                    )
                    image_btn = gr.Button(
                        "✨ Generate Image",
                        variant="primary",
                        size="lg",
                        elem_classes="generate-btn"
                    )
                with gr.Column(scale=1, min_width=500):
                    gr.Markdown("### 🎨 Generated Image")
                    image_output = gr.Image(
                        label="",
                        type="filepath",
                        height=500,
                        show_label=False
                    )
                    image_status = gr.Textbox(
                        label="📋 Status",
                        lines=15,
                        max_lines=50,
                        interactive=False,
                        elem_classes="status-box",
                        placeholder="Status messages will appear here..."
                    )
            image_btn.click(
                fn=generate_image,
                inputs=[image_prompt, image_aspect, image_cfg, image_seed, image_checkpoint, image_config],
                outputs=[image_output, image_status],
                show_progress=True,
                queue=True  # Use queue to handle long-running operations
            )
| with gr.Tab("π¬ Text-to-Video", id="video_tab"): | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=400): | |
| gr.Markdown("### βοΈ Generation Settings") | |
| with gr.Group(): | |
| video_prompt = gr.Textbox( | |
| label="π Prompt", | |
| placeholder="a corgi dog looks at the camera", | |
| lines=4 | |
| ) | |
| video_config = gr.Textbox( | |
| label="βοΈ Config Path", | |
| value="configs/starflow-v_7B_t2v_caus_480p.yaml", | |
| interactive=False, | |
| visible=False # Hidden - not needed for users | |
| ) | |
| with gr.Group(): | |
| gr.Markdown("#### π¬ Video Settings") | |
| video_aspect = gr.Dropdown( | |
| label="Aspect Ratio", | |
| choices=["16:9", "1:1", "4:3"], | |
| value="16:9" | |
| ) | |
| video_cfg = gr.Slider( | |
| label="CFG Scale", | |
| minimum=1.0, | |
| maximum=10.0, | |
| value=3.5, | |
| step=0.1 | |
| ) | |
| video_seed = gr.Number( | |
| value=99, | |
| label="π² Seed", | |
| precision=0 | |
| ) | |
| video_length = gr.Slider( | |
| label="Target Length (frames)", | |
| minimum=81, | |
| maximum=481, | |
| value=81, | |
| step=80 | |
| ) | |
| with gr.Group(): | |
| gr.Markdown("#### πΌοΈ Optional Input") | |
| video_input_image = gr.Image( | |
| label="Input Image (optional)", | |
| type="filepath" | |
| ) | |
| # Hidden checkpoint field | |
| video_checkpoint = gr.Textbox( | |
| label="Model Checkpoint Path (auto-downloaded)", | |
| value=DEFAULT_VIDEO_CHECKPOINT, | |
| visible=False | |
| ) | |
| video_btn = gr.Button( | |
| "β¨ Generate Video", | |
| variant="primary", | |
| size="lg", | |
| elem_classes="generate-btn" | |
| ) | |
| with gr.Column(scale=1, min_width=500): | |
| gr.Markdown("### π¬ Generated Video") | |
| video_output = gr.Video( | |
| label="", | |
| height=500, | |
| show_label=False | |
| ) | |
| video_status = gr.Textbox( | |
| label="π Status", | |
| lines=15, | |
| max_lines=50, | |
| interactive=False, | |
| elem_classes="status-box", | |
| placeholder="Status messages will appear here..." | |
| ) | |
| video_btn.click( | |
| fn=generate_video, | |
| inputs=[video_prompt, video_aspect, video_cfg, video_seed, video_length, | |
| video_checkpoint, video_config, video_input_image], | |
| outputs=[video_output, video_status], | |
| show_progress=True, | |
| queue=True # Use queue to handle long-running operations | |
| ) | |
| if __name__ == "__main__": | |
| # Enable queue for long-running operations | |
| demo.queue(default_concurrency_limit=1) | |
| # Password protection - users don't need HF accounts! | |
| # Change these to your desired username/password | |
| # For multiple users, use: auth=[("user1", "pass1"), ("user2", "pass2")] | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| auth=("starflow", "im30"), # Change password! | |
| share=False, # Set to True if you want public Gradio link | |
| max_threads=1 # Limit threads for stability | |
| ) | |