import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import json
from tkg_dm import TKGDMPipeline
def create_canvas_image(width=512, height=512):
    """Build the blank preview canvas used for bounding-box visualisation.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        A light-gray RGB image with a 64-pixel grid and three lines of
        usage hints drawn in the top-left corner.
    """
    background = (240, 240, 240)
    grid_color = (200, 200, 200)
    hint_color = (100, 100, 100)
    grid_step = 64

    canvas = Image.new('RGB', (width, height), background)
    pen = ImageDraw.Draw(canvas)

    # Vertical grid lines first, then the horizontal ones.
    for col in range(0, width, grid_step):
        pen.line([(col, 0), (col, height)], fill=grid_color, width=1)
    for row in range(0, height, grid_step):
        pen.line([(0, row), (width, row)], fill=grid_color, width=1)

    # Usage hints, stacked 15 pixels apart.
    hints = (
        (10, "Draw bounding boxes to define reserved regions"),
        (25, "Click and drag to create boxes"),
        (40, "Use 'Clear Boxes' to reset"),
    )
    for top, message in hints:
        pen.text((10, top), message, fill=hint_color)

    return canvas
def draw_boxes_on_canvas(boxes, width=512, height=512):
    """Render bounding boxes on top of a fresh preview canvas.

    Args:
        boxes: Iterable of ``(x1, y1, x2, y2)`` tuples. Each coordinate is
            multiplied by ``width``/``height`` to obtain pixel positions.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        An RGB image with every box outlined (red outer, yellow inner),
        tinted with a translucent red fill and labelled "Box N".
    """
    img = create_canvas_image(width, height)

    for i, (x1, y1, x2, y2) in enumerate(boxes):
        # Convert to pixel coordinates and make sure corners are ordered,
        # because ImageDraw rejects rectangles whose second corner precedes
        # the first.
        left, right = sorted((int(x1 * width), int(x2 * width)))
        top, bottom = sorted((int(y1 * height), int(y2 * height)))

        draw = ImageDraw.Draw(img)
        draw.rectangle([left, top, right, bottom], outline='red', width=3)
        # The inner highlight is inset by one pixel on every side; skip it
        # for boxes too small to contain it instead of passing inverted
        # coordinates to Pillow.
        if right - left >= 2 and bottom - top >= 2:
            draw.rectangle(
                [left + 1, top + 1, right - 1, bottom - 1],
                outline='yellow',
                width=2,
            )

        # Translucent fill: drawn on a transparent layer and composited so
        # the grid underneath stays visible.
        overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
        ImageDraw.Draw(overlay).rectangle(
            [left, top, right, bottom], fill=(255, 0, 0, 50)
        )
        img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')

        # Label: the black copy is drawn last, so it sits on top of the
        # white copy that is offset by one pixel.
        draw = ImageDraw.Draw(img)
        label = f"Box {i+1}"
        draw.text((left + 5, top + 5), label, fill='white')
        draw.text((left + 4, top + 4), label, fill='black')

    return img
def add_bounding_box(bbox_str, x1, y1, x2, y2):
    """Append a new bounding box to the semicolon-separated box string.

    Args:
        bbox_str: Existing boxes as ``"x1,y1,x2,y2;..."`` (may be empty or
            ``None``).
        x1, y1, x2, y2: Corner coordinates of the new box. Values are
            clamped to ``[0, 1]`` and reordered so that ``x1 <= x2`` and
            ``y1 <= y2``.

    Returns:
        ``(updated_str, preview_image)``. The string is returned unchanged
        when a coordinate is missing or the box is smaller than 0.02 in
        either dimension.
    """
    bbox_str = bbox_str or ""
    coords = (x1, y1, x2, y2)

    # An emptied numeric input arrives as None; there is nothing to add.
    if any(value is None for value in coords):
        return bbox_str, sync_text_to_canvas(bbox_str)

    # Clamp every coordinate individually before ordering, so out-of-range
    # input can never yield an inverted or out-of-bounds box.
    x1, y1, x2, y2 = (min(1.0, max(0.0, float(value))) for value in coords)
    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)

    # Check minimum size
    if x2 - x1 < 0.02 or y2 - y1 < 0.02:
        return bbox_str, sync_text_to_canvas(bbox_str)

    new_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"
    if bbox_str.strip():
        updated_str = bbox_str + ";" + new_box
    else:
        updated_str = new_box
    return updated_str, sync_text_to_canvas(updated_str)
def remove_last_box(bbox_str):
    """Remove the last bounding box from the semicolon-separated string.

    Args:
        bbox_str: Boxes as ``"x1,y1,x2,y2;..."`` (may be empty or ``None``).

    Returns:
        ``(updated_str, preview_image)`` with the final box dropped.
    """
    if not bbox_str or not bbox_str.strip():
        return "", create_canvas_image()

    # Ignore empty segments (e.g. from a trailing ';') so that the entry
    # removed is always a real box rather than an empty string.
    boxes = [segment for segment in bbox_str.split(';') if segment.strip()]
    updated_str = ';'.join(boxes[:-1])
    return updated_str, sync_text_to_canvas(updated_str)
def create_box_builder_interface():
    """Return the static help text describing the bounding-box builder."""
    # The text starts and ends with a newline, hence the empty first and
    # last entries.
    help_lines = (
        "",
        "📦 Bounding Box Builder",
        "Define reserved regions where content generation will be suppressed. Use coordinate inputs for precision.",
        "Instructions:",
        "• Each box is defined by (x1, y1, x2, y2) where coordinates range from 0.0 to 1.0",
        "• (0,0) is top-left corner, (1,1) is bottom-right corner",
        "• Multiple boxes are separated by semicolons",
        "• Red/yellow boxes in preview show reserved regions",
        "💡 Tips: Start with default values (0.2,0.2,0.8,0.4) for a center box, then adjust coordinates as needed.",
        "",
    )
    return "\n".join(help_lines)
def load_preset_boxes(preset_name):
    """Look up the box string for a named preset layout.

    Args:
        preset_name: One of ``center_box``, ``top_strip``, ``bottom_strip``,
            ``left_right``, ``corners`` or ``frame``.

    Returns:
        The preset as ``"x1,y1,x2,y2;..."``, or ``""`` for an unknown name.
    """
    layouts = dict(
        center_box="0.3,0.3,0.7,0.7",
        top_strip="0.0,0.0,1.0,0.3",
        bottom_strip="0.0,0.7,1.0,1.0",
        left_right="0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8",
        corners="0.0,0.0,0.4,0.4;0.6,0.0,1.0,0.4;0.0,0.6,0.4,1.0;0.6,0.6,1.0,1.0",
        frame="0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8",
    )
    try:
        return layouts[preset_name]
    except KeyError:
        return ""
def extract_boxes_from_annotated_image(annotated_data):
    """Placeholder for box extraction from annotated image data.

    The argument is currently ignored; an empty list is always returned.
    """
    # Reserved for use with more advanced annotation tools.
    extracted = list()
    return extracted
def update_canvas_with_boxes(annotated_data):
    """Placeholder canvas-update hook for drawn boxes.

    The argument is currently ignored. Returns a blank canvas together with
    an empty string.
    """
    blank_canvas = create_canvas_image()
    empty_text = ""
    return blank_canvas, empty_text
def clear_bounding_boxes():
    """Reset the box state: return a blank canvas and an empty box string."""
    fresh_canvas = create_canvas_image()
    return (fresh_canvas, "")
def parse_bounding_boxes(bbox_str):
    """
    Parse bounding boxes from string format.

    Expected format: "x1,y1,x2,y2;x1,y1,x2,y2" or empty for legacy mode.

    Every coordinate is clamped to [0, 1] and each box is reordered so that
    x1 <= x2 and y1 <= y2. Empty segments and segments that do not contain
    exactly four values are skipped.

    Returns:
        A list of ``(x1, y1, x2, y2)`` float tuples, or ``None`` when the
        input is empty, contains no usable box, or contains a value that is
        not a number.
    """
    if not bbox_str or not bbox_str.strip():
        return None

    def _clamp(value):
        # Clamp both ends; clamping only one side per corner lets values
        # such as 1.5 or -0.2 leak through and can invert the box.
        return max(0.0, min(1.0, value))

    boxes = []
    try:
        for box_str in bbox_str.split(';'):
            if not box_str.strip():
                continue
            coords = [float(part.strip()) for part in box_str.split(',')]
            if len(coords) != 4:
                continue
            x1, y1, x2, y2 = (_clamp(value) for value in coords)
            boxes.append((min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)))
    except ValueError as e:
        # float() is the only call here that can fail for string input.
        print(f"Error parsing bounding boxes: {e}")
        return None
    return boxes if boxes else None
def sync_text_to_canvas(bbox_str):
    """Turn the box text into a preview image.

    Returns the canvas with the parsed boxes drawn on it, or a blank canvas
    when the text yields no boxes.
    """
    parsed = parse_bounding_boxes(bbox_str)
    if not parsed:
        return create_canvas_image()
    return draw_boxes_on_canvas(parsed)
def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str):
    """Generate an image with the TKG-DM pipeline, or fall back to a demo image.

    Args:
        prompt: Text prompt passed to the pipeline.
        ch0_shift, ch1_shift, ch2_shift, ch3_shift: Per-channel shifts,
            passed to the pipeline as ``channel_shifts``.
        intensity: Multiplier applied to ``shift_percent``.
        steps: Number of inference steps; rounded to an integer.
        shift_percent: Base shift percent before the intensity multiplier.
        blur_sigma: Blur sigma; ``0`` is forwarded as ``None``.
        model_type: Model type forwarded to ``TKGDMPipeline``.
        custom_model_id: Optional model id; blank or ``None`` means no
            custom model.
        bounding_boxes_str: Boxes as ``"x1,y1,x2,y2;..."``.

    Returns:
        The pipeline result, or the image from ``create_demo_visualization``
        when the pipeline is unavailable or raises.
    """
    # Parsed before the try block so the fallback below can always use it.
    bounding_boxes = parse_bounding_boxes(bounding_boxes_str)

    try:
        # Try to use actual TKG-DM pipeline with CPU fallback
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Blank or missing custom model id means "use the default model".
        model_id = (custom_model_id or "").strip() or None
        pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
        if pipeline.pipe is None:
            raise RuntimeError("Pipeline not available")

        channel_shifts = [ch0_shift, ch1_shift, ch2_shift, ch3_shift]
        # Apply intensity multiplier to base shift percent
        final_shift_percent = shift_percent * intensity
        # Use blur sigma (0 means auto-calculate)
        blur_sigma_param = None if blur_sigma == 0 else blur_sigma

        # Default to center box if no boxes specified
        if not bounding_boxes:
            bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]

        return pipeline(
            prompt=prompt,
            channel_shifts=channel_shifts,
            bounding_boxes=bounding_boxes,
            target_shift_percent=final_shift_percent,
            blur_sigma=blur_sigma_param,
            # A slider may deliver a float; a step count must be an integer.
            num_inference_steps=int(round(steps)),
            guidance_scale=7.5,
        )
    except Exception as e:
        # Deliberate best-effort: any pipeline failure degrades to the demo
        # visualization instead of surfacing an error in the UI.
        print(f"Using demo mode due to: {e}")
        return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)
def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
    """Create a 512x512 demo visualization of the TKG-DM concept.

    Args:
        prompt: Prompt text; its first 40 characters are printed on the image.
        ch0_shift, ch1_shift, ch2_shift: Shifts mapped to the red, green and
            blue background components as ``128 + int(shift * 127)``,
            clamped to ``[0, 255]``.
        ch3_shift: Only shown in the channel read-out text.
        bounding_boxes: Optional list of ``(x1, y1, x2, y2)`` tuples scaled
            by the image size; defaults to a single center box.

    Returns:
        An RGB image with the tinted background, outlined boxes and captions.
    """
    size = 512

    def _to_byte(shift):
        # Map a shift around 0 to a colour component around mid-gray.
        return max(0, min(255, 128 + int(shift * 127)))

    approx_color = (_to_byte(ch0_shift), _to_byte(ch1_shift), _to_byte(ch2_shift))
    img = Image.new('RGB', (size, size), approx_color)
    draw = ImageDraw.Draw(img)

    # Default to center box if none specified
    if not bounding_boxes:
        bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]

    for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
        # Order the corners: ImageDraw rejects inverted rectangles.
        left, right = sorted((int(x1 * size), int(x2 * size)))
        top, bottom = sorted((int(y1 * size), int(y2 * size)))

        draw.rectangle([left, top, right, bottom], outline='yellow', width=3)
        # The inner outline is inset by two pixels per side; skip it when
        # the box is too small to hold it.
        if right - left >= 4 and bottom - top >= 4:
            draw.rectangle(
                [left + 2, top + 2, right - 2, bottom - 2],
                outline='orange',
                width=2,
            )
        # Add box label
        draw.text((left + 5, top + 5), f"Box {i+1}", fill='white')

    # Captions
    draw.text((10, 10), "TKG-DM Demo", fill='white')
    draw.text((10, 30), f"Prompt: {(prompt or '')[:40]}...", fill='white')
    draw.text((10, 480), f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]", fill='white')
    return img
# Create intuitive interface with step-by-step workflow.
# The default region is shared by the coordinates textbox and the preview so
# both start out showing the same box.
default_boxes_str = load_preset_boxes("center_box")

with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation", theme=gr.themes.Default()) as demo:
    # Header section with workflow explanation
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("""
# 🎨 SAWNA: Space-Aware Text-to-Image Generation
Create professional images with **guaranteed empty spaces** for headlines, logos, and product shots.
Perfect for advertisements, posters, and UI mockups.
""")
        with gr.Column(scale=2):
            gr.Markdown("""
### 🚀 Quick Start:
1. **Describe** your image in the text prompt
2. **Choose** where to keep empty (preset or custom)
3. **Adjust** colors and style (optional)
4. **Generate** with guaranteed reserved regions
""")
    with gr.Row():
        gr.Markdown("""
---
💡 **How it works**: SAWNA uses advanced noise manipulation to suppress content generation in your specified regions,
ensuring they remain empty for your design elements while maintaining high quality in other areas.
""")
    gr.Markdown("## 🎯 Create Your Space-Aware Image")

    # Main workflow section
    with gr.Row():
        # Left column - Input and controls
        with gr.Column(scale=2):
            # Step 1: Text Prompt
            with gr.Group():
                gr.Markdown("## 📝 Step 1: Describe Your Image")
                prompt = gr.Textbox(
                    value="A majestic lion in a natural landscape",
                    label="Text Prompt",
                    placeholder="Describe what you want to generate...",
                    lines=2
                )

            # Step 2: Reserved Regions
            with gr.Group():
                gr.Markdown("## 🔲 Step 2: Define Empty Regions")
                gr.Markdown("*Choose areas that must stay empty for your design elements*")
                # Quick presets
                with gr.Row():
                    preset_dropdown = gr.Dropdown(
                        choices=[
                            ("None (Default Center)", "center_box"),
                            ("Top Banner", "top_strip"),
                            ("Bottom Banner", "bottom_strip"),
                            ("Side Panels", "left_right"),
                            ("Corner Logos", "corners"),
                            ("Full Frame", "frame")
                        ],
                        label="🚀 Quick Presets",
                        value="center_box"
                    )
                # Manual box creation
                gr.Markdown("**Or Create Custom Boxes:**")
                with gr.Row():
                    with gr.Column(scale=1):
                        x1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Left (X1)")
                        x2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Right (X2)")
                    with gr.Column(scale=1):
                        y1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Top (Y1)")
                        y2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Bottom (Y2)")
                with gr.Row():
                    add_box_btn = gr.Button("➕ Add Region", variant="primary", size="sm")
                    remove_box_btn = gr.Button("❌ Remove Last", variant="secondary", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
                # Text representation
                bounding_boxes_str = gr.Textbox(
                    value=default_boxes_str,
                    label="📋 Region Coordinates",
                    placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (auto-updated)",
                    lines=2,
                    info="Coordinates are normalized (0.0 = left/top, 1.0 = right/bottom)"
                )

            # Step 3: Color and Style Controls
            with gr.Group():
                gr.Markdown("## 🎨 Step 3: Fine-tune Colors")
                gr.Markdown("*Adjust the 4 latent channels to control image colors and style*")
                with gr.Row():
                    ch0_shift = gr.Slider(-1.0, 1.0, 0.0, label="💡 Brightness", info="Overall image brightness")
                    ch1_shift = gr.Slider(-1.0, 1.0, 1.0, label="🔵 Blue-Red Balance", info="Shift toward blue (+) or red (-)")
                with gr.Row():
                    ch2_shift = gr.Slider(-1.0, 1.0, 1.0, label="🟡 Yellow-Blue Balance", info="Shift toward yellow (+) or dark blue (-)")
                    ch3_shift = gr.Slider(-1.0, 1.0, 0.0, label="⚪ Contrast", info="Adjust overall contrast")

        # Right column - Preview and results
        with gr.Column(scale=1):
            # Preview section
            with gr.Group():
                gr.Markdown("## 👁️ Preview: Empty Regions")
                bbox_preview = gr.Image(
                    # Start with the default region drawn, matching the
                    # initial contents of the coordinates textbox.
                    value=sync_text_to_canvas(default_boxes_str),
                    label="Reserved Regions Visualization",
                    interactive=False,
                    type="pil"
                )
                gr.Markdown("*Yellow boxes show where content will be suppressed*")

    # Advanced controls (collapsible)
    with gr.Accordion("🎛️ Advanced Generation Settings", open=True):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Generation Settings")
                intensity = gr.Slider(0.5, 3.0, 1.0, label="Effect Intensity", info="How strongly to suppress content in empty regions")
                # step=1 keeps the step count a whole number.
                steps = gr.Slider(10, 100, 25, step=1, label="Quality Steps", info="More steps = higher quality, slower generation")
                gr.Markdown("### TKG-DM Technical Controls")
                shift_percent = gr.Slider(0.01, 0.15, 0.07, step=0.005, label="🎯 Shift Percent", info="Base shift percentage for noise optimization (±7% default)")
                blur_sigma = gr.Slider(0.0, 5.0, 0.0, step=0.1, label="🌫️ Blur Sigma", info="Gaussian blur for soft transitions (0 = auto)")
            with gr.Column():
                gr.Markdown("### Model Selection")
                model_type = gr.Dropdown(
                    ["sd1.5", "sdxl", "sd2.1"],
                    value="sd1.5",
                    label="Model Architecture",
                    info="SDXL for highest quality, SD1.5 for speed"
                )
                custom_model_id = gr.Textbox(
                    "",
                    label="Custom Model (Optional)",
                    placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
                    info="Use any Hugging Face Stable Diffusion model"
                )

    # Generation section
    with gr.Row():
        with gr.Column(scale=1):
            generate_btn = gr.Button(
                "🎨 Generate Space-Aware Image",
                variant="primary",
                size="lg",
                elem_id="generate-btn"
            )
            gr.Markdown("*Click to create your image with guaranteed empty regions*")
        with gr.Column(scale=3):
            output_image = gr.Image(
                label="✨ Generated Image",
                type="pil",
                height=500,
                elem_id="output-image"
            )

    # Examples section
    with gr.Accordion("📚 Example Prompts & Layouts", open=False):
        gr.Markdown("""
### Try these professional design scenarios:
Click any example to load it automatically and see how SAWNA handles different layout requirements.
""")
        # Each row follows the order of the `inputs` list below.
        gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna",
                    0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
                    "0.3,0.3,0.7,0.7"
                ],
                [
                    "Modern cityscape with skyscrapers at sunset",
                    -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
                    "0.0,0.0,1.0,0.3"
                ],
                [
                    "Vintage luxury car on mountain road",
                    0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
                    "0.0,0.7,1.0,1.0"
                ],
                [
                    "Space astronaut floating in nebula",
                    0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
                    "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8"
                ],
                [
                    "Product photography: premium watch (fine-tuned)",
                    0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
                    "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8"
                ]
            ],
            inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                    intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
            label="Professional Use Cases"
        )

    # Add custom CSS for better styling (injected by a script on page load)
    demo.load(fn=None, js="""
function() {
// Add custom styling
const style = document.createElement('style');
style.textContent = `
.gradio-container {
max-width: 1400px !important;
margin: auto;
}
#generate-btn {
background: linear-gradient(45deg, #7c3aed, #a855f7) !important;
border: none !important;
font-weight: bold !important;
padding: 15px 30px !important;
font-size: 16px !important;
}
#output-image {
border-radius: 12px !important;
box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
}
.gr-group {
border-radius: 12px !important;
border: 1px solid #e5e7eb !important;
padding: 20px !important;
margin-bottom: 20px !important;
}
.gr-accordion {
border-radius: 8px !important;
border: 1px solid #d1d5db !important;
}
`;
document.head.appendChild(style);
return [];
}
""")

    # Event handlers
    def generate_wrapper(*args):
        """Forward the UI inputs, in order, to generate_tkg_dm_image."""
        return generate_tkg_dm_image(*args)

    def clear_boxes_handler():
        """Clear boxes and update preview"""
        return "", create_canvas_image()

    def update_preview_from_text(bbox_str):
        """Update preview image from text input"""
        return sync_text_to_canvas(bbox_str)

    def add_box_handler(bbox_str, x1, y1, x2, y2):
        """Add a new box and update preview"""
        return add_bounding_box(bbox_str, x1, y1, x2, y2)

    def remove_box_handler(bbox_str):
        """Remove last box and update preview"""
        return remove_last_box(bbox_str)

    def load_preset_handler(preset_name):
        """Load preset boxes and update preview"""
        if not preset_name:
            return "", create_canvas_image()
        # Every preset, including the default center box, comes from
        # load_preset_boxes so the coordinates live in a single place.
        preset_str = load_preset_boxes(preset_name)
        return preset_str, sync_text_to_canvas(preset_str)

    # Preset dropdown
    preset_dropdown.change(
        fn=load_preset_handler,
        inputs=[preset_dropdown],
        outputs=[bounding_boxes_str, bbox_preview]
    )
    # Add box button
    add_box_btn.click(
        fn=add_box_handler,
        inputs=[bounding_boxes_str, x1_input, y1_input, x2_input, y2_input],
        outputs=[bounding_boxes_str, bbox_preview]
    )
    # Remove last box button
    remove_box_btn.click(
        fn=remove_box_handler,
        inputs=[bounding_boxes_str],
        outputs=[bounding_boxes_str, bbox_preview]
    )
    # Clear all boxes button
    clear_btn.click(
        fn=clear_boxes_handler,
        outputs=[bounding_boxes_str, bbox_preview]
    )
    # Sync text to preview canvas
    bounding_boxes_str.change(
        fn=update_preview_from_text,
        inputs=[bounding_boxes_str],
        outputs=[bbox_preview]
    )
    # Generate button
    generate_btn.click(
        fn=generate_wrapper,
        inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
        outputs=[output_image]
    )
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # share=True asks Gradio to create a public share link for the app.
    demo.launch(share=True)