"""
Hugging Face Space for STARFlow
Text-to-Image and Text-to-Video Generation
This app allows you to run STARFlow inference on Hugging Face GPU infrastructure.
"""
# Fix OpenMP warning - MUST be set BEFORE importing torch
import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
# Fix CUDA memory fragmentation
os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True'
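# Note: PYTORCH_ALLOC_CONF is the newer name for the deprecated PYTORCH_CUDA_ALLOC_CONF;
# expandable_segments lets the CUDA caching allocator grow segments instead of fragmenting memory.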
import warnings
import gradio as gr
import torch
import subprocess
import traceback
from pathlib import Path
# Suppress harmless warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.distributed.reduce_op.*")
# Try to import huggingface_hub for downloading checkpoints
try:
from huggingface_hub import hf_hub_download
HF_HUB_AVAILABLE = True
except ImportError:
HF_HUB_AVAILABLE = False
print("⚠️ huggingface_hub not available. Install with: pip install huggingface_hub")
# Detect Hugging Face Spaces (the Spaces runtime sets SPACE_ID); informational only
HF_SPACE = os.environ.get("SPACE_ID") is not None
# Import spaces module for ZeroGPU support (required for GPU allocation)
try:
import spaces
SPACES_AVAILABLE = True
except ImportError:
SPACES_AVAILABLE = False
print("⚠️ spaces module not available. GPU decorator will be skipped.")
# Default checkpoint paths (if uploaded to Space Files)
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
# Model Hub repositories (if using Hugging Face Model Hub)
# Set these to your Model Hub repo IDs if you upload checkpoints there
# Format: "username/repo-name"
IMAGE_CHECKPOINT_REPO = "GlobalStudio/starflow-3b-checkpoint" # Update this after creating Model Hub repo
VIDEO_CHECKPOINT_REPO = "GlobalStudio/starflow-v-7b-checkpoint" # Update this after creating Model Hub repo
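# Example (illustrative) of publishing a checkpoint so this app can fetch it:
#   huggingface-cli upload GlobalStudio/starflow-3b-checkpoint ckpts/starflow_3B_t2i_256x256.pth starflow_3B_t2i_256x256.pth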
def get_checkpoint_path(checkpoint_file, default_local_path, repo_id=None, filename=None):
"""Get checkpoint path, downloading from Hub if needed."""
# If user uploaded a file, use it
if checkpoint_file is not None and checkpoint_file != "":
if hasattr(checkpoint_file, 'name'):
return checkpoint_file.name
checkpoint_str = str(checkpoint_file)
# If it's a file path that exists, use it
if os.path.exists(checkpoint_str):
return checkpoint_str
# If it's the default path but doesn't exist, continue to download
if checkpoint_str == default_local_path and not os.path.exists(checkpoint_str):
pass # Continue to download logic below
elif checkpoint_str != default_local_path:
return checkpoint_str
# Try local path first
if os.path.exists(default_local_path):
return default_local_path
# Try downloading from Model Hub if configured
if repo_id and filename and HF_HUB_AVAILABLE:
try:
# Use /workspace if available (HF Spaces persistent), otherwise use local cache
if os.path.exists("/workspace"):
cache_dir = "/workspace/checkpoints"
elif os.path.exists("/tmp"):
cache_dir = "/tmp/checkpoints"
else:
# Local development: use project directory or user cache
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "starflow")
os.makedirs(cache_dir, exist_ok=True)
# Check if already downloaded
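            # hf_hub_download caches files under <cache_dir>/models--<org>--<name>/snapshots/<revision>/,
            # so globbing over the revision directory finds any previously downloaded copy.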
possible_path = os.path.join(cache_dir, "models--" + repo_id.replace("/", "--"), "snapshots", "*", filename)
import glob
existing = glob.glob(possible_path)
if existing:
checkpoint_path = existing[0]
print(f"βœ… Using cached checkpoint: {checkpoint_path}")
return checkpoint_path
# Download with progress tracking
import time
start_time = time.time()
print(f"πŸ“₯ Downloading checkpoint from {repo_id}...")
print(f"File size: ~15.5 GB - This may take 10-30 minutes")
print(f"Cache directory: {cache_dir}")
print(f"Progress will be shown below...")
            # Download the checkpoint. huggingface_hub prints its own tqdm progress bar, and
            # recent releases resume interrupted downloads automatically, so the deprecated
            # resume_download flag is omitted here.
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                local_files_only=False,
            )
elapsed = time.time() - start_time
print(f"βœ… Download completed in {elapsed/60:.1f} minutes")
print(f"βœ… Checkpoint at: {checkpoint_path}")
return checkpoint_path
except Exception as e:
error_detail = str(e)
if "404" in error_detail or "not found" in error_detail.lower():
return None, f"Checkpoint not found in Model Hub.\n\nPlease verify:\n1. Repo exists: https://huggingface.co/{repo_id}\n2. File exists: {filename}\n3. Repo is Public (not Private)\n\nError: {error_detail}"
return None, f"Error downloading checkpoint: {error_detail}\n\nThis may take 10-30 minutes for a 14GB file. Please wait or check your internet connection."
# No checkpoint found
return None, f"Checkpoint not found. Please upload a checkpoint file or configure Model Hub repository."
# Verify CUDA availability (will be True on HF Spaces with GPU hardware)
if torch.cuda.is_available():
print(f"βœ… CUDA available! Device: {torch.cuda.get_device_name(0)}")
print(f" CUDA Version: {torch.version.cuda}")
print(f" PyTorch Version: {torch.__version__}")
print(f" GPU Count: {torch.cuda.device_count()}")
print(f" Current Device: {torch.cuda.current_device()}")
else:
print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
print(f" PyTorch Version: {torch.__version__}")
# Apply @spaces.GPU decorator if available (required for ZeroGPU)
# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
"""Generate image from text prompt."""
return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
generate_image = spaces.GPU(generate_image)
print("βœ… ZeroGPU decorator applied to generate_image")
def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
"""Generate image from text prompt (implementation)."""
# Get checkpoint path (from upload, local, or Model Hub)
status_msg = ""
# Handle checkpoint file (might be string from hidden Textbox)
if checkpoint_file == DEFAULT_IMAGE_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
# Use Model Hub download
checkpoint_file = None
result = get_checkpoint_path(
checkpoint_file,
DEFAULT_IMAGE_CHECKPOINT,
IMAGE_CHECKPOINT_REPO,
"starflow_3B_t2i_256x256.pth"
)
if isinstance(result, tuple) and result[0] is None:
return None, result[1] # Error message
checkpoint_path = result
# Show status
status_msg += f"Using checkpoint: {checkpoint_path}\n"
if not os.path.exists(checkpoint_path):
return None, f"Error: Checkpoint file not found at {checkpoint_path}.\n\nPlease verify:\n1. Model Hub repo exists: {IMAGE_CHECKPOINT_REPO}\n2. File name matches: starflow_3B_t2i_256x256.pth\n3. Repo is Public (not Private)\n\nAttempting to download from Model Hub..."
if not config_path or not os.path.exists(config_path):
return None, "Error: Config file not found. Please ensure config file exists."
status_msg += "Starting image generation...\n"
status_msg += "⏱️ Timing breakdown:\n"
status_msg += " - Checkpoint download: 10-30 min (first time only, ~15.5 GB)\n"
status_msg += " - Model loading: 2-5 min (first time only)\n"
status_msg += " - Image generation: 1-3 min\n"
status_msg += " - Subsequent runs: Only generation time (~1-3 min)\n"
try:
        # Create output directory for generated images.
        # On HF Spaces this is runtime-only storage and is not visible in the Files tab.
output_dir = Path("outputs")
output_dir.mkdir(exist_ok=True)
# Run sampling command
cmd = [
"python", "sample.py",
"--model_config_path", config_path,
"--checkpoint_path", checkpoint_path,
"--caption", prompt,
"--sample_batch_size", "1",
"--cfg", str(cfg),
"--aspect_ratio", aspect_ratio,
"--seed", str(seed),
"--save_folder", "1",
"--finetuned_vae", "none",
"--jacobi", "1",
"--jacobi_th", "0.001",
"--jacobi_block_size", "16",
"--logdir", str(output_dir) # Set logdir to outputs directory for images
]
status_msg += "πŸš€ Running generation...\n"
status_msg += "πŸ“Š Current step: Model inference (checkpoint should already be downloaded)\n"
# Note about log file location
status_msg += f"\nπŸ“‹ LOGS:\n"
status_msg += f" All logs will be shown in the status output below\n"
status_msg += f" (Logs are captured in real-time)\n\n"
# Ensure GPU environment variables are passed to subprocess
env = os.environ.copy()
# Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
if 'CUDA_VISIBLE_DEVICES' in env:
print(f"βœ… CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
# Verify GPU is available before starting
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
allocated = torch.cuda.memory_allocated(0) / 1024**3
reserved = torch.cuda.memory_reserved(0) / 1024**3
free_memory = total_memory - reserved
status_msg += f"βœ… GPU available: {gpu_name}\n"
status_msg += f" Total Memory: {total_memory:.1f} GB\n"
status_msg += f" Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB, Free: {free_memory:.2f} GB\n"
# Warn if memory is low
if free_memory < 2.0:
status_msg += f"⚠️ Warning: Low GPU memory ({free_memory:.2f} GB free). Model may not fit.\n"
else:
status_msg += "⚠️ Warning: CUDA not available, will use CPU (very slow)\n"
# Run with timeout (45 minutes max - allows for download + generation)
# Capture output and write to log file
# Note: If process is killed (e.g., GPU abort), we still capture what was output
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
env=env, # Pass environment variables (including CUDA_VISIBLE_DEVICES)
timeout=2700
)
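        # subprocess.run raises TimeoutExpired if the 45-minute limit is exceeded; that case
        # is handled in the except clause below rather than through the return code.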
# Build comprehensive log content for display (not relying on file access)
log_content_parts = []
log_content_parts.append("=== GENERATION LOG ===\n\n")
log_content_parts.append(f"Command: {' '.join(cmd)}\n\n")
log_content_parts.append(f"Environment Variables:\n")
log_content_parts.append(f" CUDA_VISIBLE_DEVICES={env.get('CUDA_VISIBLE_DEVICES', 'not set')}\n")
log_content_parts.append(f" CUDA_AVAILABLE={torch.cuda.is_available()}\n")
if torch.cuda.is_available():
log_content_parts.append(f" GPU_NAME={torch.cuda.get_device_name(0)}\n")
log_content_parts.append(f" GPU_MEMORY_TOTAL={torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB\n")
log_content_parts.append(f"\n")
log_content_parts.append("=== STDOUT ===\n")
log_content_parts.append(result.stdout if result.stdout else "(empty)\n")
log_content_parts.append("\n\n=== STDERR ===\n")
log_content_parts.append(result.stderr if result.stderr else "(empty)\n")
log_content_parts.append(f"\n\n=== RETURN CODE: {result.returncode} ===\n")
log_content = ''.join(log_content_parts)
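        # The assembled log is surfaced through the status textbox because files written at
        # runtime are not downloadable from the Space UI.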
if result.returncode != 0:
error_msg = f"❌ Error during generation (return code: {result.returncode})\n\n"
error_msg += f"⚠️ IMPORTANT: Scroll down in this status box to see full error details!\n"
error_msg += f" The error message below contains STDERR, STDOUT, and debugging info.\n\n"
# Check for GPU abort or CUDA errors
error_output = (result.stderr + result.stdout).lower()
if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output or "killed" in error_output:
error_msg += "⚠️ GPU ERROR DETECTED\n\n"
error_msg += "Possible causes:\n"
error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
error_msg += "2. CUDA out of memory (model too large for GPU)\n"
error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n"
error_msg += "4. Process killed due to memory limit\n\n"
error_msg += "Solutions:\n"
error_msg += "- Model is now using bfloat16/float16 for memory efficiency\n"
error_msg += "- Try again (GPU may have been released)\n"
error_msg += "- Check Space logs for detailed error\n"
error_msg += "- Ensure @spaces.GPU decorator is applied\n"
error_msg += "- Consider using paid GPU tier for longer runs\n"
error_msg += "- If issue persists, model may be too large for available GPU\n\n"
# Show detailed error information
error_msg += f"\n{'='*80}\n"
error_msg += f"πŸ“‹ DETAILED ERROR LOGS\n"
error_msg += f"{'='*80}\n\n"
# Show return code and command
error_msg += f"Return Code: {result.returncode}\n"
error_msg += f"Command: {' '.join(cmd)}\n\n"
# Show STDERR (usually contains the actual error)
if result.stderr:
error_msg += f"=== STDERR (Error Output) ===\n"
error_msg += f"{result.stderr}\n\n"
else:
error_msg += f"⚠️ No STDERR output (process may have been killed silently)\n\n"
# Show STDOUT (may contain useful info)
if result.stdout:
error_msg += f"=== STDOUT (Standard Output) ===\n"
# Show last 5000 chars of stdout
stdout_preview = result.stdout[-5000:] if len(result.stdout) > 5000 else result.stdout
error_msg += f"{stdout_preview}\n\n"
# Show full log content directly in error message (no file access needed)
error_msg += f"=== FULL GENERATION LOG ===\n"
error_msg += f"{log_content}\n\n"
# Instructions on where to find more info
error_msg += f"{'='*80}\n"
error_msg += f"πŸ“ ADDITIONAL DEBUGGING:\n"
error_msg += f"{'='*80}\n"
error_msg += f"1. Check the Space 'Logs' tab for container logs\n"
error_msg += f"2. Look for messages from sample.py\n"
error_msg += f"3. Check for GPU abort or CUDA errors\n"
error_msg += f"4. All logs are shown above in this error message\n"
error_msg += f"{'='*80}\n"
return None, error_msg
status_msg += "Generation complete. Looking for output...\n"
status_msg += f"\nπŸ’‘ Note: Generated images will appear directly in the UI above.\n"
status_msg += f" The outputs/ folder is runtime-generated and not visible in Files tab.\n\n"
# Find the generated image
# The sample.py script saves to logdir/model_name/...
# Model name is derived from checkpoint path stem
checkpoint_stem = Path(checkpoint_path).stem
model_output_dir = output_dir / checkpoint_stem
status_msg += f"Searching in: {model_output_dir}\n"
status_msg += f"Also searching recursively in: {output_dir}\n"
# Search in model-specific directory first, then recursively
search_paths = [model_output_dir, output_dir]
output_files = []
for search_path in search_paths:
if search_path.exists():
# Look for PNG, JPG, JPEG files
found = list(search_path.glob("**/*.png")) + list(search_path.glob("**/*.jpg")) + list(search_path.glob("**/*.jpeg"))
output_files.extend(found)
status_msg += f"Found {len(found)} files in {search_path}\n"
if output_files:
# Get the most recent file
latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
# Use absolute path for Gradio to access the file
image_path = str(latest_file.absolute())
status_msg += f"βœ… Found image: {image_path}\n"
status_msg += f"\nπŸ’‘ Note: Generated images appear directly in the UI above.\n"
status_msg += f" The outputs/ folder is not visible in Files tab (runtime files).\n"
return image_path, status_msg + "\nβœ… Success! Image generated."
else:
# Debug: list what's actually in the directory
debug_info = f"\n\nDebug info:\n"
debug_info += f"Output dir exists: {output_dir.exists()}\n"
if output_dir.exists():
debug_info += f"Contents of {output_dir}:\n"
for item in output_dir.iterdir():
debug_info += f" - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
if model_output_dir.exists():
debug_info += f"\nContents of {model_output_dir}:\n"
for item in model_output_dir.iterdir():
debug_info += f" - {item.name} ({'dir' if item.is_dir() else 'file'})\n"
error_msg = status_msg + f"Error: Generated image not found.\n"
error_msg += f"Searched in: {output_dir} and {model_output_dir}\n"
error_msg += debug_info
if log_content:
error_msg += f"\n\nπŸ“‹ Full log details:\n{log_content[-2000:]}"
else:
error_msg += f"\n\nCheck stdout:\n{result.stdout[-1000:] if result.stdout else '(no output)'}"
return None, error_msg
except subprocess.TimeoutExpired:
error_msg = f"❌ Generation timed out after 45 minutes\n\n"
error_msg += f"This may indicate:\n"
error_msg += f"- Model loading is taking too long\n"
error_msg += f"- GPU timeout (ZeroGPU may have limits)\n"
error_msg += f"- Process hung or stuck\n\n"
error_msg += f"Try:\n"
error_msg += f"- Check Space Logs tab for more details\n"
error_msg += f"- Try generating again\n"
error_msg += f"- Check if GPU is still available\n"
return None, error_msg
except Exception as e:
# Get full traceback for debugging
error_traceback = traceback.format_exc()
error_msg = f"❌ Exception occurred during image generation:\n\n"
error_msg += f"Error Type: {type(e).__name__}\n"
error_msg += f"Error Message: {str(e)}\n\n"
error_msg += f"=== FULL TRACEBACK ===\n{error_traceback}\n\n"
error_msg += f"πŸ’‘ TIP: Scroll down in this status box to see full error details.\n"
error_msg += f" You can also copy the error message using the copy button.\n"
return None, error_msg
# Apply @spaces.GPU decorator if available (required for ZeroGPU)
# IMPORTANT: Decorator must be applied at module level for ZeroGPU to detect it at startup
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
"""Generate video from text prompt."""
return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
# Apply decorator if spaces module is available (ZeroGPU detection happens at import time)
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
generate_video = spaces.GPU(generate_video)
print("βœ… ZeroGPU decorator applied to generate_video")
def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
"""Generate video from text prompt (implementation)."""
# Handle checkpoint file (might be string from hidden Textbox)
if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
# Use Model Hub download
checkpoint_file = None
# Get checkpoint path (from upload, local, or Model Hub)
result = get_checkpoint_path(
checkpoint_file,
DEFAULT_VIDEO_CHECKPOINT,
VIDEO_CHECKPOINT_REPO,
"starflow-v_7B_t2v_caus_480p_v3.pth"
)
if isinstance(result, tuple) and result[0] is None:
return None, result[1] # Error message
checkpoint_path = result
if not os.path.exists(checkpoint_path):
return None, f"Error: Checkpoint file not found at {checkpoint_path}."
if not config_path or not os.path.exists(config_path):
return None, "Error: Config file not found. Please ensure config file exists."
# Handle input image
input_image_path = None
if input_image is not None:
if hasattr(input_image, 'name'):
input_image_path = input_image.name
else:
input_image_path = str(input_image)
try:
# Create output directory
output_dir = Path("outputs")
output_dir.mkdir(exist_ok=True)
# Run sampling command
cmd = [
"python", "sample.py",
"--model_config_path", config_path,
"--checkpoint_path", checkpoint_path,
"--caption", prompt,
"--sample_batch_size", "1",
"--cfg", str(cfg),
"--aspect_ratio", aspect_ratio,
"--seed", str(seed),
"--out_fps", "16",
"--save_folder", "1",
"--jacobi", "1",
"--jacobi_th", "0.001",
"--finetuned_vae", "none",
"--disable_learnable_denoiser", "0",
"--jacobi_block_size", "32",
"--target_length", str(target_length)
]
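        # Same sample.py entry point as text-to-image, plus video-specific flags
        # (--out_fps, --target_length) and a larger Jacobi block size.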
if input_image_path and os.path.exists(input_image_path):
cmd.extend(["--input_image", input_image_path])
else:
cmd.extend(["--input_image", "none"])
# Ensure GPU environment variables are passed to subprocess
env = os.environ.copy()
# Preserve CUDA_VISIBLE_DEVICES if set (important for ZeroGPU)
if 'CUDA_VISIBLE_DEVICES' in env:
print(f"βœ… CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}")
# Verify GPU is available before starting
if torch.cuda.is_available():
print(f"βœ… GPU available: {torch.cuda.get_device_name(0)}")
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
env=env, # Pass environment variables (including CUDA_VISIBLE_DEVICES)
timeout=3600 # 60 minutes for video generation
)
if result.returncode != 0:
error_msg = f"❌ Error during video generation (return code: {result.returncode})\n\n"
# Check for GPU abort or CUDA errors
error_output = (result.stderr + result.stdout).lower()
if "gpu aborted" in error_output or "cuda" in error_output or "out of memory" in error_output:
error_msg += "⚠️ GPU ERROR DETECTED\n\n"
error_msg += "Possible causes:\n"
error_msg += "1. GPU timeout (ZeroGPU may have a 5-10 min limit)\n"
error_msg += "2. CUDA out of memory (video generation needs more GPU memory)\n"
error_msg += "3. GPU allocation failed (ZeroGPU not detected)\n\n"
error_msg += "Solutions:\n"
error_msg += "- Try again (GPU may have been released)\n"
error_msg += "- Check Space logs for detailed error\n"
error_msg += "- Ensure @spaces.GPU decorator is applied\n"
error_msg += "- Consider using paid GPU tier for longer runs\n\n"
error_msg += f"=== STDERR ===\n{result.stderr}\n\n"
error_msg += f"=== STDOUT ===\n{result.stdout}\n"
return None, error_msg
# Find the generated video
output_files = list(output_dir.glob("**/*.mp4")) + list(output_dir.glob("**/*.gif"))
if output_files:
latest_file = max(output_files, key=lambda p: p.stat().st_mtime)
return str(latest_file), "Success! Video generated."
else:
return None, "Error: Generated video not found."
except Exception as e:
# Get full traceback for debugging
error_traceback = traceback.format_exc()
error_msg = f"❌ Exception occurred during video generation:\n\n"
error_msg += f"Error Type: {type(e).__name__}\n"
error_msg += f"Error Message: {str(e)}\n\n"
error_msg += f"=== FULL TRACEBACK ===\n{error_traceback}\n"
return None, error_msg
# Create Gradio interface
with gr.Blocks(title="STARFlow - Text-to-Image & Video Generation") as demo:
# Add custom CSS using gr.HTML
gr.HTML("""
<style>
.gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
}
.main-header {
text-align: center;
padding: 2rem 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 2rem;
}
.main-header h1 {
margin: 0;
font-size: 2.5rem;
font-weight: 700;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
.main-header p {
margin: 0.5rem 0 0 0;
font-size: 1.1rem;
opacity: 0.95;
}
.info-box {
background: #f0f4ff;
border-left: 4px solid #667eea;
padding: 1rem;
border-radius: 5px;
margin-bottom: 1.5rem;
}
.input-section {
background: #fafafa;
padding: 1.5rem;
border-radius: 10px;
border: 1px solid #e0e0e0;
}
.output-section {
background: white;
padding: 1.5rem;
border-radius: 10px;
border: 1px solid #e0e0e0;
box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
color: white !important;
font-weight: 600 !important;
padding: 0.75rem 2rem !important;
border-radius: 8px !important;
border: none !important;
font-size: 1rem !important;
transition: transform 0.2s, box-shadow 0.2s !important;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
.status-box {
font-family: 'Monaco', 'Menlo', monospace;
font-size: 0.9rem;
line-height: 1.6;
}
</style>
""")
# Header
gr.HTML("""
<div class="main-header">
<h1>🎨 STARFlow</h1>
<p>Scalable Transformer Auto-Regressive Flow</p>
<p style="font-size: 0.95rem; margin-top: 0.5rem; opacity: 0.9;">
Generate high-quality images and videos from text prompts
</p>
</div>
""")
# Info box
gr.Markdown("""
<div class="info-box">
<strong>ℹ️ Note:</strong> Checkpoints are automatically downloaded from Model Hub on first use.
    The first generation may take 10-30 minutes for checkpoint download and model loading; later runs are much faster.
</div>
""")
with gr.Tabs() as tabs:
with gr.Tab("πŸ–ΌοΈ Text-to-Image", id="image_tab"):
with gr.Row():
with gr.Column(scale=1, min_width=400):
gr.Markdown("### βš™οΈ Generation Settings")
with gr.Group():
image_prompt = gr.Textbox(
label="πŸ“ Prompt",
placeholder="a film still of a cat playing piano",
lines=4
)
image_config = gr.Textbox(
label="βš™οΈ Config Path",
value="configs/starflow_3B_t2i_256x256.yaml",
interactive=False,
visible=False # Hidden - not needed for users
)
with gr.Group():
gr.Markdown("#### 🎨 Image Settings")
image_aspect = gr.Dropdown(
label="Aspect Ratio",
choices=["1:1", "2:3", "3:2", "16:9", "9:16", "4:5", "5:4"],
value="1:1"
)
image_cfg = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=10.0,
value=3.6,
step=0.1
)
image_seed = gr.Number(
value=999,
label="🎲 Seed",
precision=0
)
# Hidden checkpoint field
image_checkpoint = gr.Textbox(
label="Model Checkpoint Path (auto-downloaded)",
value=DEFAULT_IMAGE_CHECKPOINT,
visible=False
)
image_btn = gr.Button(
"✨ Generate Image",
variant="primary",
size="lg",
elem_classes="generate-btn"
)
with gr.Column(scale=1, min_width=500):
gr.Markdown("### 🎨 Generated Image")
image_output = gr.Image(
label="",
type="filepath",
height=500,
show_label=False
)
image_status = gr.Textbox(
label="πŸ“Š Status",
lines=15,
max_lines=50,
interactive=False,
elem_classes="status-box",
placeholder="Status messages will appear here..."
)
image_btn.click(
fn=generate_image,
inputs=[image_prompt, image_aspect, image_cfg, image_seed, image_checkpoint, image_config],
outputs=[image_output, image_status],
show_progress=True,
queue=True # Use queue to handle long-running operations
)
with gr.Tab("🎬 Text-to-Video", id="video_tab"):
with gr.Row():
with gr.Column(scale=1, min_width=400):
gr.Markdown("### βš™οΈ Generation Settings")
with gr.Group():
video_prompt = gr.Textbox(
label="πŸ“ Prompt",
placeholder="a corgi dog looks at the camera",
lines=4
)
video_config = gr.Textbox(
label="βš™οΈ Config Path",
value="configs/starflow-v_7B_t2v_caus_480p.yaml",
interactive=False,
visible=False # Hidden - not needed for users
)
with gr.Group():
gr.Markdown("#### 🎬 Video Settings")
video_aspect = gr.Dropdown(
label="Aspect Ratio",
choices=["16:9", "1:1", "4:3"],
value="16:9"
)
video_cfg = gr.Slider(
label="CFG Scale",
minimum=1.0,
maximum=10.0,
value=3.5,
step=0.1
)
video_seed = gr.Number(
value=99,
label="🎲 Seed",
precision=0
)
video_length = gr.Slider(
label="Target Length (frames)",
minimum=81,
maximum=481,
value=81,
step=80
)
with gr.Group():
gr.Markdown("#### πŸ–ΌοΈ Optional Input")
video_input_image = gr.Image(
label="Input Image (optional)",
type="filepath"
)
# Hidden checkpoint field
video_checkpoint = gr.Textbox(
label="Model Checkpoint Path (auto-downloaded)",
value=DEFAULT_VIDEO_CHECKPOINT,
visible=False
)
video_btn = gr.Button(
"✨ Generate Video",
variant="primary",
size="lg",
elem_classes="generate-btn"
)
with gr.Column(scale=1, min_width=500):
gr.Markdown("### 🎬 Generated Video")
video_output = gr.Video(
label="",
height=500,
show_label=False
)
video_status = gr.Textbox(
label="πŸ“Š Status",
lines=15,
max_lines=50,
interactive=False,
elem_classes="status-box",
placeholder="Status messages will appear here..."
)
video_btn.click(
fn=generate_video,
inputs=[video_prompt, video_aspect, video_cfg, video_seed, video_length,
video_checkpoint, video_config, video_input_image],
outputs=[video_output, video_status],
show_progress=True,
queue=True # Use queue to handle long-running operations
)
if __name__ == "__main__":
# Enable queue for long-running operations
demo.queue(default_concurrency_limit=1)
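    # A single concurrent worker serializes requests so two generations never compete for the GPU.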
# Password protection - users don't need HF accounts!
# Change these to your desired username/password
# For multiple users, use: auth=[("user1", "pass1"), ("user2", "pass2")]
demo.launch(
server_name="0.0.0.0",
server_port=7860,
auth=("starflow", "im30"), # Change password!
share=False, # Set to True if you want public Gradio link
max_threads=1 # Limit threads for stability
)