"""Gradio demo UI for SAWNA / TKG-DM space-aware text-to-image generation.

The module defines helpers for editing and previewing normalized bounding
boxes ("reserved regions"), a generation entry point that calls
``TKGDMPipeline`` and falls back to a drawn demo image on any failure, and
the ``gr.Blocks`` application object ``demo``.
"""
import json
import math
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw

from tkg_dm import TKGDMPipeline

# A box is (x1, y1, x2, y2) in normalized [0, 1] coordinates, origin top-left.
Box = Tuple[float, float, float, float]

# Default reserved region, used both as the initial textbox value and as the
# fallback when no boxes are supplied at generation time.
DEFAULT_BOX: Box = (0.3, 0.3, 0.7, 0.7)
DEFAULT_BOXES_STR = "0.3,0.3,0.7,0.7"

# Boxes narrower or shorter than this (normalized units) are rejected by
# add_bounding_box.
MIN_BOX_SIZE = 0.02


def create_canvas_image(width=512, height=512):
    """Create a blank preview canvas.

    Returns a light-gray RGB image of the requested size with a 64-pixel grid
    and three lines of instruction text in the top-left corner.
    """
    img = Image.new('RGB', (width, height), (240, 240, 240))  # Light gray background
    draw = ImageDraw.Draw(img)

    # Add grid lines for better visualization
    grid_size = 64
    grid_color = (200, 200, 200)
    for x in range(0, width, grid_size):
        draw.line([(x, 0), (x, height)], fill=grid_color, width=1)
    for y in range(0, height, grid_size):
        draw.line([(0, y), (width, y)], fill=grid_color, width=1)

    # Add instructions
    text_color = (100, 100, 100)
    draw.text((10, 10), "Draw bounding boxes to define reserved regions", fill=text_color)
    draw.text((10, 25), "Click and drag to create boxes", fill=text_color)
    draw.text((10, 40), "Use 'Clear Boxes' to reset", fill=text_color)

    return img


def _normalize_box(x1, y1, x2, y2) -> Box:
    """Order each axis low-to-high and clamp all four coordinates into [0, 1].

    Every coordinate is clamped on both sides, so a box lying entirely outside
    the unit square collapses to a zero-size box on the nearest edge instead
    of coming out inverted (e.g. x1=1.5, x2=1.0).
    """
    def clamp(value):
        return min(1.0, max(0.0, value))

    left, right = sorted((x1, x2))
    top, bottom = sorted((y1, y2))
    return clamp(left), clamp(top), clamp(right), clamp(bottom)


def _draw_inset_outline(draw, px1, py1, px2, py2, inset, color, width):
    """Draw a rectangle outline shrunk by ``inset`` pixels on every side.

    Skipped when the box is too small for the inset rectangle to keep
    x1 >= x0 and y1 >= y0, because ImageDraw.rectangle raises ValueError on
    inverted coordinates.
    """
    if px2 - px1 >= 2 * inset and py2 - py1 >= 2 * inset:
        draw.rectangle(
            [px1 + inset, py1 + inset, px2 - inset, py2 - inset],
            outline=color,
            width=width,
        )


def draw_boxes_on_canvas(boxes, width=512, height=512):
    """Draw bounding boxes on a fresh canvas.

    Args:
        boxes: iterable of (x1, y1, x2, y2) tuples in normalized coordinates.
        width: canvas width in pixels.
        height: canvas height in pixels.

    Returns:
        An RGB image with each box outlined in red/yellow, tinted with a
        translucent red fill and labelled "Box N" (1-based).
    """
    img = create_canvas_image(width, height)
    draw = ImageDraw.Draw(img)

    for i, (x1, y1, x2, y2) in enumerate(boxes):
        # Convert normalized coordinates to pixel coordinates
        px1, py1 = int(x1 * width), int(y1 * height)
        px2, py2 = int(x2 * width), int(y2 * height)

        # Draw bounding box
        draw.rectangle([px1, py1, px2, py2], outline='red', width=3)
        _draw_inset_outline(draw, px1, py1, px2, py2, 1, 'yellow', 2)

        # Add semi-transparent fill. alpha_composite needs RGBA inputs and
        # returns a new image, so the draw handle is re-created afterwards.
        overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.rectangle([px1, py1, px2, py2], fill=(255, 0, 0, 50))
        img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')
        draw = ImageDraw.Draw(img)

        # Add box label: a white copy first, then black text drawn on top of
        # it one pixel up and to the left.
        label = f"Box {i+1}"
        draw.text((px1 + 5, py1 + 5), label, fill='white')
        draw.text((px1 + 4, py1 + 4), label, fill='black')

    return img


def add_bounding_box(bbox_str, x1, y1, x2, y2):
    """Append a new bounding box to the semicolon-separated box string.

    The coordinates are ordered and clamped into [0, 1]. The box is rejected
    (inputs returned unchanged) when any coordinate is missing or non-finite,
    or when the resulting box is smaller than MIN_BOX_SIZE on either axis.

    Returns:
        (updated_str, preview_image)
    """
    coords = (x1, y1, x2, y2)
    # A cleared numeric input can arrive as None; treat it like a rejected box.
    if any(c is None or not math.isfinite(c) for c in coords):
        return bbox_str, sync_text_to_canvas(bbox_str)

    # Ensure coordinates are in correct order and valid range
    x1, y1, x2, y2 = _normalize_box(*coords)

    # Check minimum size
    if x2 - x1 < MIN_BOX_SIZE or y2 - y1 < MIN_BOX_SIZE:
        return bbox_str, sync_text_to_canvas(bbox_str)

    new_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"

    # Drop surrounding whitespace and trailing separators so that appending
    # never produces an empty ";;" segment.
    existing = (bbox_str or "").strip().rstrip("; \t\r\n")
    updated_str = f"{existing};{new_box}" if existing else new_box

    return updated_str, sync_text_to_canvas(updated_str)


def remove_last_box(bbox_str):
    """Remove the last bounding box from the box string.

    Empty segments (for example from a trailing ';') are ignored, so the last
    real box is the one that gets removed.

    Returns:
        (updated_str, preview_image)
    """
    segments = [seg.strip() for seg in (bbox_str or "").split(';') if seg.strip()]
    if not segments:
        return "", create_canvas_image()

    segments.pop()
    updated_str = ';'.join(segments)
    return updated_str, sync_text_to_canvas(updated_str)


def create_box_builder_interface():
    """Return the static help text for the bounding box builder."""
    return """

📦 Bounding Box Builder

Define reserved regions where content generation will be suppressed. Use coordinate inputs for precision.

Instructions:
• Each box is defined by (x1, y1, x2, y2) where coordinates range from 0.0 to 1.0
• (0,0) is top-left corner, (1,1) is bottom-right corner
• Multiple boxes are separated by semicolons
• Red/yellow boxes in preview show reserved regions
💡 Tips: Start with default values (0.2,0.2,0.8,0.4) for a center box, then adjust coordinates as needed.
"""


def load_preset_boxes(preset_name):
    """Return the box string for a named preset, or "" for an unknown name."""
    presets = {
        "center_box": "0.3,0.3,0.7,0.7",
        "top_strip": "0.0,0.0,1.0,0.3",
        "bottom_strip": "0.0,0.7,1.0,1.0",
        "left_right": "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8",
        "corners": "0.0,0.0,0.4,0.4;0.6,0.0,1.0,0.4;0.0,0.6,0.4,1.0;0.6,0.6,1.0,1.0",
        "frame": "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8",
    }
    return presets.get(preset_name, "")


def extract_boxes_from_annotated_image(annotated_data):
    """Extract bounding boxes from annotated image data.

    Placeholder for future enhancement: currently ignores its argument and
    returns an empty list.
    """
    # This would be used with more advanced annotation tools
    return []


def update_canvas_with_boxes(annotated_data):
    """Update canvas when boxes are drawn.

    Placeholder for future enhancement: currently ignores its argument and
    returns (blank_canvas, "").
    """
    # For now, return the current canvas
    return create_canvas_image(), ""


def clear_bounding_boxes():
    """Clear all bounding boxes; returns (blank_canvas, "")."""
    return create_canvas_image(), ""


def parse_bounding_boxes(bbox_str) -> Optional[List[Box]]:
    """Parse bounding boxes from string format.

    Expected format: "x1,y1,x2,y2;x1,y1,x2,y2" or empty for legacy mode.

    Each box is ordered and clamped into [0, 1]. Segments that do not have
    exactly four values are skipped. If any value is not a number, or is
    NaN/infinite, the error is printed and the whole input is rejected.

    Returns:
        A list of (x1, y1, x2, y2) tuples, or None when the input is empty,
        malformed, or contains no usable box.
    """
    if not bbox_str or not bbox_str.strip():
        return None

    boxes = []
    try:
        for box_str in bbox_str.split(';'):
            if not box_str.strip():
                continue
            coords = [float(part) for part in box_str.split(',')]
            if len(coords) != 4:
                continue
            # NaN/inf parse as floats but break clamping and int() later on.
            if not all(math.isfinite(c) for c in coords):
                raise ValueError(f"non-finite coordinate in {box_str.strip()!r}")
            boxes.append(_normalize_box(*coords))
    except ValueError as e:
        print(f"Error parsing bounding boxes: {e}")
        return None

    return boxes if boxes else None


def sync_text_to_canvas(bbox_str):
    """Render the preview image for a box string (blank canvas if no boxes)."""
    boxes = parse_bounding_boxes(bbox_str)
    if boxes:
        return draw_boxes_on_canvas(boxes)
    return create_canvas_image()


def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                          intensity, steps, shift_percent, blur_sigma,
                          model_type, custom_model_id, bounding_boxes_str):
    """Generate image using TKG-DM or fallback demo.

    Builds a TKGDMPipeline for the selected model and calls it with the four
    channel shifts, the parsed bounding boxes (default: one center box),
    ``shift_percent * intensity`` as the target shift percent, the blur sigma
    (0 is passed as None, i.e. auto) and the step count. Any exception,
    including an unavailable pipeline, is printed and replaced by the image
    from create_demo_visualization.

    Returns:
        Whatever the pipeline call returns, or the PIL demo image on failure.
    """
    # Parse before the try block so the fallback below can always use the
    # result, whatever fails later.
    bounding_boxes = parse_bounding_boxes(bounding_boxes_str)

    try:
        # Try to use actual TKG-DM pipeline with CPU fallback
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Initialize pipeline with selected model type and optional custom model ID
        model_id = (custom_model_id or "").strip() or None
        # NOTE(review): the pipeline is constructed on every call; consider
        # caching it per (model_id, model_type, device) if construction is costly.
        pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)

        if pipeline.pipe is None:
            raise RuntimeError("Pipeline not available")

        # Use actual pipeline with direct latent channel control
        channel_shifts = [ch0_shift, ch1_shift, ch2_shift, ch3_shift]

        # Apply intensity multiplier to base shift percent
        final_shift_percent = shift_percent * intensity

        # Use blur sigma (0 means auto-calculate)
        blur_sigma_param = None if blur_sigma == 0 else blur_sigma

        # Default to center box if no boxes specified
        if not bounding_boxes:
            bounding_boxes = [DEFAULT_BOX]

        return pipeline(
            prompt=prompt,
            channel_shifts=channel_shifts,
            bounding_boxes=bounding_boxes,
            target_shift_percent=final_shift_percent,
            blur_sigma=blur_sigma_param,
            # Slider values may arrive as floats; a step count must be an int.
            num_inference_steps=int(steps),
            guidance_scale=7.5,
        )

    except Exception as e:  # deliberate best-effort boundary: fall back to demo
        print(f"Using demo mode due to: {e}")
        # Fallback to demo visualization
        return create_demo_visualization(
            prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes
        )


def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                              bounding_boxes=None):
    """Create demo visualization of TKG-DM concept.

    Produces a 512x512 RGB image whose background color is derived from the
    first three channel shifts (128 + shift * 127 per component, clamped to
    0..255), with the bounding boxes outlined and labelled, and with the
    prompt (first 40 characters) and all four channel shifts written as text.
    ``ch3_shift`` only appears in the text. Uses one center box when
    ``bounding_boxes`` is empty or None.
    """
    def to_byte(shift):
        return max(0, min(255, 128 + int(shift * 127)))

    # Convert latent channel shifts to approximate RGB for visualization
    approx_color = (
        to_byte(ch0_shift),  # Luminance -> Red
        to_byte(ch1_shift),  # Color1 -> Green
        to_byte(ch2_shift),  # Color2 -> Blue
    )

    img = Image.new('RGB', (512, 512), approx_color)
    draw = ImageDraw.Draw(img)

    # Draw space-aware bounding boxes
    if not bounding_boxes:
        # Default to center box if none specified
        bounding_boxes = [DEFAULT_BOX]

    for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
        px1, py1 = int(x1 * 512), int(y1 * 512)
        px2, py2 = int(x2 * 512), int(y2 * 512)

        # Draw bounding box with a second, inset outline
        draw.rectangle([px1, py1, px2, py2], outline='yellow', width=3)
        _draw_inset_outline(draw, px1, py1, px2, py2, 2, 'orange', 2)

        # Add box label
        draw.text((px1 + 5, py1 + 5), f"Box {i+1}", fill='white')

    # Add text
    prompt_text = prompt or ""
    draw.text((10, 10), "TKG-DM Demo", fill='white')
    draw.text((10, 30), f"Prompt: {prompt_text[:40]}...", fill='white')
    draw.text(
        (10, 480),
        f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]",
        fill='white',
    )

    return img


# Create intuitive interface with step-by-step workflow
with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation",
               theme=gr.themes.Default()) as demo:

    # Header section with workflow explanation
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("""
            # 🎨 SAWNA: Space-Aware Text-to-Image Generation

            Create professional images with **guaranteed empty spaces** for headlines, logos, and product shots.
            Perfect for advertisements, posters, and UI mockups.
            """)
        with gr.Column(scale=2):
            gr.Markdown("""
            ### 🚀 Quick Start:
            1. **Describe** your image in the text prompt
            2. **Choose** where to keep empty (preset or custom)
            3. **Adjust** colors and style (optional)
            4. **Generate** with guaranteed reserved regions
            """)

    with gr.Row():
        gr.Markdown("""
        ---
        💡 **How it works**: SAWNA uses advanced noise manipulation to suppress content generation in your specified regions, ensuring they remain empty for your design elements while maintaining high quality in other areas.
        """)

    gr.Markdown("## 🎯 Create Your Space-Aware Image")

    # Main workflow section
    with gr.Row():
        # Left column - Input and controls
        with gr.Column(scale=2):
            # Step 1: Text Prompt
            with gr.Group():
                gr.Markdown("## 📝 Step 1: Describe Your Image")
                prompt = gr.Textbox(
                    value="A majestic lion in a natural landscape",
                    label="Text Prompt",
                    placeholder="Describe what you want to generate...",
                    lines=2
                )

            # Step 2: Reserved Regions
            with gr.Group():
                gr.Markdown("## 🔲 Step 2: Define Empty Regions")
                gr.Markdown("*Choose areas that must stay empty for your design elements*")

                # Quick presets
                with gr.Row():
                    preset_dropdown = gr.Dropdown(
                        choices=[
                            ("None (Default Center)", "center_box"),
                            ("Top Banner", "top_strip"),
                            ("Bottom Banner", "bottom_strip"),
                            ("Side Panels", "left_right"),
                            ("Corner Logos", "corners"),
                            ("Full Frame", "frame")
                        ],
                        label="🚀 Quick Presets",
                        value="center_box"
                    )

                # Manual box creation
                gr.Markdown("**Or Create Custom Boxes:**")
                with gr.Row():
                    with gr.Column(scale=1):
                        x1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Left (X1)")
                        x2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Right (X2)")
                    with gr.Column(scale=1):
                        y1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Top (Y1)")
                        y2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Bottom (Y2)")

                with gr.Row():
                    add_box_btn = gr.Button("➕ Add Region", variant="primary", size="sm")
                    remove_box_btn = gr.Button("❌ Remove Last", variant="secondary", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")

                # Text representation
                bounding_boxes_str = gr.Textbox(
                    value=DEFAULT_BOXES_STR,
                    label="📋 Region Coordinates",
                    placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (auto-updated)",
                    lines=2,
                    info="Coordinates are normalized (0.0 = left/top, 1.0 = right/bottom)"
                )

            # Step 3: Color and Style Controls
            with gr.Group():
                gr.Markdown("## 🎨 Step 3: Fine-tune Colors")
                gr.Markdown("*Adjust the 4 latent channels to control image colors and style*")

                with gr.Row():
                    ch0_shift = gr.Slider(-1.0, 1.0, 0.0, label="💡 Brightness",
                                          info="Overall image brightness")
                    ch1_shift = gr.Slider(-1.0, 1.0, 1.0, label="🔵 Blue-Red Balance",
                                          info="Shift toward blue (+) or red (-)")
                with gr.Row():
                    ch2_shift = gr.Slider(-1.0, 1.0, 1.0, label="🟡 Yellow-Blue Balance",
                                          info="Shift toward yellow (+) or dark blue (-)")
                    ch3_shift = gr.Slider(-1.0, 1.0, 0.0, label="⚪ Contrast",
                                          info="Adjust overall contrast")

        # Right column - Preview and results
        with gr.Column(scale=1):
            # Preview section
            with gr.Group():
                gr.Markdown("## 👁️ Preview: Empty Regions")
                bbox_preview = gr.Image(
                    # Render the same default box the textbox starts with, so
                    # the preview and the coordinates agree on first load.
                    value=sync_text_to_canvas(DEFAULT_BOXES_STR),
                    label="Reserved Regions Visualization",
                    interactive=False,
                    type="pil"
                )
                gr.Markdown("*Yellow boxes show where content will be suppressed*")

    # Advanced controls (collapsible)
    with gr.Accordion("🎛️ Advanced Generation Settings", open=True):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Generation Settings")
                intensity = gr.Slider(0.5, 3.0, 1.0, label="Effect Intensity",
                                      info="How strongly to suppress content in empty regions")
                steps = gr.Slider(10, 100, 25, label="Quality Steps",
                                  info="More steps = higher quality, slower generation")

                gr.Markdown("### TKG-DM Technical Controls")
                shift_percent = gr.Slider(0.01, 0.15, 0.07, step=0.005, label="🎯 Shift Percent",
                                          info="Base shift percentage for noise optimization (±7% default)")
                blur_sigma = gr.Slider(0.0, 5.0, 0.0, step=0.1, label="🌫️ Blur Sigma",
                                       info="Gaussian blur for soft transitions (0 = auto)")

            with gr.Column():
                gr.Markdown("### Model Selection")
                model_type = gr.Dropdown(
                    ["sd1.5", "sdxl", "sd2.1"],
                    value="sd1.5",
                    label="Model Architecture",
                    info="SDXL for highest quality, SD1.5 for speed"
                )
                custom_model_id = gr.Textbox(
                    "",
                    label="Custom Model (Optional)",
                    placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
                    info="Use any Hugging Face Stable Diffusion model"
                )

    # Generation section
    with gr.Row():
        with gr.Column(scale=1):
            generate_btn = gr.Button(
                "🎨 Generate Space-Aware Image",
                variant="primary",
                size="lg",
                elem_id="generate-btn"
            )
            gr.Markdown("*Click to create your image with guaranteed empty regions*")

        with gr.Column(scale=3):
            output_image = gr.Image(
                label="✨ Generated Image",
                type="pil",
                height=500,
                elem_id="output-image"
            )

    # Examples section
    with gr.Accordion("📚 Example Prompts & Layouts", open=False):
        gr.Markdown("""
        ### Try these professional design scenarios:
        Click any example to load it automatically and see how SAWNA handles different layout requirements.
        """)

        # Each row follows the order of `inputs` below.
        gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna",
                    0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
                    "0.3,0.3,0.7,0.7"
                ],
                [
                    "Modern cityscape with skyscrapers at sunset",
                    -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
                    "0.0,0.0,1.0,0.3"
                ],
                [
                    "Vintage luxury car on mountain road",
                    0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
                    "0.0,0.7,1.0,1.0"
                ],
                [
                    "Space astronaut floating in nebula",
                    0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
                    "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8"
                ],
                [
                    "Product photography: premium watch (fine-tuned)",
                    0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
                    "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8"
                ]
            ],
            inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps,
                    shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
            label="Professional Use Cases"
        )

    # Add custom CSS for better styling. The JS comment must stay on its own
    # line: a `//` comment runs to end of line and would otherwise swallow
    # the code that follows it.
    demo.load(fn=None, js="""
    function() {
        // Add custom styling
        const style = document.createElement('style');
        style.textContent = `
            .gradio-container {
                max-width: 1400px !important;
                margin: auto;
            }
            #generate-btn {
                background: linear-gradient(45deg, #7c3aed, #a855f7) !important;
                border: none !important;
                font-weight: bold !important;
                padding: 15px 30px !important;
                font-size: 16px !important;
            }
            #output-image {
                border-radius: 12px !important;
                box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
            }
            .gr-group {
                border-radius: 12px !important;
                border: 1px solid #e5e7eb !important;
                padding: 20px !important;
                margin-bottom: 20px !important;
            }
            .gr-accordion {
                border-radius: 8px !important;
                border: 1px solid #d1d5db !important;
            }
        `;
        document.head.appendChild(style);
        return [];
    }
    """)

    # Event handlers
    def generate_wrapper(*args):
        """Forward the UI inputs, in order, to generate_tkg_dm_image."""
        return generate_tkg_dm_image(*args)

    def clear_boxes_handler():
        """Clear boxes and update preview; returns ("", blank_canvas)."""
        return "", create_canvas_image()

    def update_preview_from_text(bbox_str):
        """Update preview image from text input."""
        return sync_text_to_canvas(bbox_str)

    def add_box_handler(bbox_str, x1, y1, x2, y2):
        """Add a new box and update preview."""
        return add_bounding_box(bbox_str, x1, y1, x2, y2)

    def remove_box_handler(bbox_str):
        """Remove last box and update preview."""
        return remove_last_box(bbox_str)

    def load_preset_handler(preset_name):
        """Load preset boxes and update preview.

        An empty selection or unknown preset name yields "" and a blank canvas.
        """
        preset_str = load_preset_boxes(preset_name) if preset_name else ""
        return preset_str, sync_text_to_canvas(preset_str)

    # Preset dropdown
    preset_dropdown.change(
        fn=load_preset_handler,
        inputs=[preset_dropdown],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Add box button
    add_box_btn.click(
        fn=add_box_handler,
        inputs=[bounding_boxes_str, x1_input, y1_input, x2_input, y2_input],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Remove last box button
    remove_box_btn.click(
        fn=remove_box_handler,
        inputs=[bounding_boxes_str],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Clear all boxes button
    clear_btn.click(
        fn=clear_boxes_handler,
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Sync text to preview canvas
    bounding_boxes_str.change(
        fn=update_preview_from_text,
        inputs=[bounding_boxes_str],
        outputs=[bbox_preview]
    )

    # Generate button
    generate_btn.click(
        fn=generate_wrapper,
        inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps,
                shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
        outputs=[output_image]
    )

if __name__ == "__main__":
    demo.launch(share=True)