File size: 23,440 Bytes
c8db08b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a002cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
import gradio as gr
import torch
import numpy as np
from PIL import Image, ImageDraw
import json
from tkg_dm import TKGDMPipeline


def create_canvas_image(width=512, height=512):
    """Return a light-gray RGB canvas with a 64px grid and usage hints.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        A freshly created RGB ``PIL.Image.Image``.
    """
    grid_step = 64
    grid_color = (200, 200, 200)
    hint_color = (100, 100, 100)

    canvas = Image.new('RGB', (width, height), (240, 240, 240))
    pen = ImageDraw.Draw(canvas)

    # Vertical grid lines first, then the horizontal ones.
    for col in range(0, width, grid_step):
        pen.line([(col, 0), (col, height)], fill=grid_color, width=1)
    for row in range(0, height, grid_step):
        pen.line([(0, row), (width, row)], fill=grid_color, width=1)

    # Usage hints stacked in the top-left corner, 15 pixels apart.
    hints = (
        "Draw bounding boxes to define reserved regions",
        "Click and drag to create boxes",
        "Use 'Clear Boxes' to reset",
    )
    for line_no, hint in enumerate(hints):
        pen.text((10, 10 + 15 * line_no), hint, fill=hint_color)

    return canvas

def draw_boxes_on_canvas(boxes, width=512, height=512):
    """Render normalized bounding boxes on top of the grid canvas.

    Args:
        boxes: Iterable of ``(x1, y1, x2, y2)`` tuples in normalized
            coordinates; each value is scaled by ``width``/``height``.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Returns:
        An RGB ``PIL.Image.Image`` with every box outlined, tinted red and
        labelled ``Box N`` (1-based).
    """
    img = create_canvas_image(width, height)
    draw = ImageDraw.Draw(img)

    for i, (x1, y1, x2, y2) in enumerate(boxes):
        # Convert normalized coordinates to pixels. Each axis is sorted so the
        # rectangle stays well-formed even when the corners arrive swapped.
        px1, px2 = sorted((int(x1 * width), int(x2 * width)))
        py1, py2 = sorted((int(y1 * height), int(y2 * height)))

        # Outer outline.
        draw.rectangle([px1, py1, px2, py2], outline='red', width=3)

        # Inner accent outline, inset by one pixel. Skipped for boxes too
        # small to hold it: the inset corners would cross over, which newer
        # Pillow versions reject with ValueError.
        if px2 - px1 >= 2 and py2 - py1 >= 2:
            draw.rectangle([px1+1, py1+1, px2-1, py2-1], outline='yellow', width=2)

        # Semi-transparent red fill, blended in through an RGBA overlay.
        overlay = Image.new('RGBA', (width, height), (0, 0, 0, 0))
        overlay_draw = ImageDraw.Draw(overlay)
        overlay_draw.rectangle([px1, py1, px2, py2], fill=(255, 0, 0, 50))
        img = Image.alpha_composite(img.convert('RGBA'), overlay).convert('RGB')
        # alpha_composite returns a new image, so the draw handle is rebuilt.
        draw = ImageDraw.Draw(img)

        # Box label: a white copy first, then a black copy offset by one
        # pixel drawn on top of it.
        label = f"Box {i+1}"
        draw.text((px1+5, py1+5), label, fill='white')
        draw.text((px1+4, py1+4), label, fill='black')

    return img

def add_bounding_box(bbox_str, x1, y1, x2, y2):
    """Append a new box to the coordinate string and refresh the preview.

    Args:
        bbox_str: Current ``"x1,y1,x2,y2;..."`` string (may be empty or None).
        x1, y1, x2, y2: Corner coordinates of the new box. Values are clamped
            to [0, 1] and the corners are reordered if needed.

    Returns:
        ``(updated_str, preview_image)``. The string is returned without a new
        box when a coordinate is missing or the box is smaller than 0.02
        along either axis.
    """
    current = bbox_str or ""

    # A coordinate may arrive as None (e.g. a cleared number field); there is
    # nothing sensible to add in that case.
    if any(value is None for value in (x1, y1, x2, y2)):
        return current, sync_text_to_canvas(current)

    # Clamp every coordinate into [0, 1] *before* ordering the corners.
    # Clamping only one side per corner leaves boxes inverted or out of range
    # when both values fall on the same side of the interval.
    x1, y1, x2, y2 = (min(1.0, max(0.0, value)) for value in (x1, y1, x2, y2))
    x1, x2 = min(x1, x2), max(x1, x2)
    y1, y2 = min(y1, y2), max(y1, y2)

    # Reject boxes that are too small to be useful.
    if x2 - x1 < 0.02 or y2 - y1 < 0.02:
        return current, sync_text_to_canvas(current)

    new_box = f"{x1:.3f},{y1:.3f},{x2:.3f},{y2:.3f}"
    updated_str = f"{current};{new_box}" if current.strip() else new_box

    return updated_str, sync_text_to_canvas(updated_str)

def remove_last_box(bbox_str):
    """Drop the last bounding box from the coordinate string.

    Args:
        bbox_str: Current ``"x1,y1,x2,y2;..."`` string (may be empty or None).

    Returns:
        ``(updated_str, preview_image)`` where ``updated_str`` no longer
        contains the last box.
    """
    if not bbox_str or not bbox_str.strip():
        return "", create_canvas_image()

    # Ignore blank segments so a trailing ';' (or ';;') does not absorb the
    # removal and leave the last real box in place.
    boxes = [segment for segment in bbox_str.split(';') if segment.strip()]

    updated_str = ';'.join(boxes[:-1])
    return updated_str, sync_text_to_canvas(updated_str)

def create_box_builder_interface():
    """Return the static HTML help panel for the bounding box builder.

    Icons and bullets are written as HTML character references so the markup
    stays pure ASCII and cannot be corrupted by an encoding mismatch.

    Returns:
        An HTML snippet as a ``str``.
    """
    return """
    <div style="background: #f8f9fa; padding: 20px; border-radius: 8px; border: 1px solid #dee2e6;">
        <h4 style="margin-top: 0; color: #495057;">&#x1F4E6; Bounding Box Builder</h4>
        <p style="color: #6c757d; margin-bottom: 15px;">
            Define reserved regions where content generation will be suppressed. Use coordinate inputs for precision.
        </p>
        <div style="background: white; padding: 15px; border-radius: 6px; border: 1px solid #ced4da; margin-bottom: 15px;">
            <strong>Instructions:</strong><br>
            &bull; Each box is defined by (x1, y1, x2, y2) where coordinates range from 0.0 to 1.0<br>
            &bull; (0,0) is top-left corner, (1,1) is bottom-right corner<br>
            &bull; Multiple boxes are separated by semicolons<br>
            &bull; Red/yellow boxes in preview show reserved regions
        </div>
        <div style="background: #e7f3ff; padding: 10px; border-radius: 6px; border: 1px solid #b3d9ff;">
            <strong>&#x1F4A1; Tips:</strong> Start with default values (0.3,0.3,0.7,0.7) for a center box, then adjust coordinates as needed.
        </div>
    </div>
    """

def load_preset_boxes(preset_name):
    """Look up the coordinate string of a named preset layout.

    Args:
        preset_name: Preset key such as ``"center_box"`` or ``"frame"``.

    Returns:
        The preset's ``"x1,y1,x2,y2;..."`` string, or ``""`` when the name is
        not a known preset.
    """
    layouts = dict(
        center_box="0.3,0.3,0.7,0.7",
        top_strip="0.0,0.0,1.0,0.3",
        bottom_strip="0.0,0.7,1.0,1.0",
        left_right="0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8",
        corners="0.0,0.0,0.4,0.4;0.6,0.0,1.0,0.4;0.0,0.6,0.4,1.0;0.6,0.6,1.0,1.0",
        frame="0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8",
    )
    try:
        return layouts[preset_name]
    except KeyError:
        # Unknown names fall back to "no boxes".
        return ""

def extract_boxes_from_annotated_image(annotated_data):
    """Placeholder for extracting boxes from annotated image data.

    The argument is currently ignored; every call returns a new empty list.
    Intended as a hook for a richer annotation tool later on.
    """
    extracted = []
    return extracted

def update_canvas_with_boxes(annotated_data):
    """Placeholder canvas-update hook for drawn boxes.

    The argument is currently ignored.

    Returns:
        ``(blank_canvas, "")``: a fresh canvas and an empty coordinate string.
    """
    blank_canvas = create_canvas_image()
    empty_coords = ""
    return blank_canvas, empty_coords

def clear_bounding_boxes():
    """Reset the box state.

    Returns:
        ``(blank_canvas, "")``: a fresh canvas followed by an empty
        coordinate string.
    """
    fresh_canvas = create_canvas_image()
    return fresh_canvas, ""

def parse_bounding_boxes(bbox_str):
    """Parse bounding boxes from their string representation.

    Expected format: ``"x1,y1,x2,y2;x1,y1,x2,y2"``. Segments are handled
    independently: a segment that is blank, non-numeric, non-finite or does
    not contain exactly four values is skipped, and the remaining valid boxes
    are still returned.

    Args:
        bbox_str: Coordinate string, possibly empty or None.

    Returns:
        A list of ``(x1, y1, x2, y2)`` float tuples with every value clamped
        to [0, 1] and ``x1 <= x2``, ``y1 <= y2``; or ``None`` when the input
        is empty or contains no valid box (legacy mode).
    """
    import math

    if not bbox_str or not bbox_str.strip():
        return None

    def _clamp(value):
        """Limit *value* to the normalized [0, 1] range."""
        return min(1.0, max(0.0, value))

    boxes = []
    for box_str in bbox_str.split(';'):
        if not box_str.strip():
            continue

        try:
            coords = [float(part.strip()) for part in box_str.split(',')]
        except ValueError as e:
            # Only this segment is dropped; earlier and later boxes survive.
            print(f"Error parsing bounding boxes: {e}")
            continue

        # NaN/inf cannot be clamped meaningfully and would break the
        # pixel conversion downstream.
        if len(coords) != 4 or not all(math.isfinite(c) for c in coords):
            continue

        # Clamp first, then order the corners: clamping only one side per
        # corner yields inverted boxes when both values lie outside [0, 1]
        # on the same side.
        x1, y1, x2, y2 = (_clamp(c) for c in coords)
        boxes.append((min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)))

    return boxes or None

def sync_text_to_canvas(bbox_str):
    """Build the preview image that matches a coordinate string.

    Returns the canvas with the parsed boxes drawn on it, or a blank canvas
    when the string yields no boxes.
    """
    parsed = parse_bounding_boxes(bbox_str)
    if not parsed:
        return create_canvas_image()
    return draw_boxes_on_canvas(parsed)


def generate_tkg_dm_image(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str):
    """Generate an image with the TKG-DM pipeline, or fall back to a demo image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        ch0_shift, ch1_shift, ch2_shift, ch3_shift: Per-channel shift values,
            forwarded as ``channel_shifts``.
        intensity: Multiplier applied to ``shift_percent``.
        steps: Number of inference steps.
        shift_percent: Base shift percentage.
        blur_sigma: Blur sigma; ``0`` is forwarded as ``None`` (auto).
        model_type: Model architecture identifier passed to the pipeline.
        custom_model_id: Optional model id; blank or None means "use default".
        bounding_boxes_str: Reserved regions as ``"x1,y1,x2,y2;..."``.

    Returns:
        Whatever the pipeline call returns on success; otherwise the image
        produced by ``create_demo_visualization``.
    """
    # Resolve the regions before entering the try block so the fallback in
    # the except handler can always rely on them. A centre box is used when
    # no valid box is specified.
    bounding_boxes = parse_bounding_boxes(bounding_boxes_str) or [(0.3, 0.3, 0.7, 0.7)]

    try:
        # Try to use the actual TKG-DM pipeline, with CPU fallback.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Blank or missing custom id -> let the pipeline pick its default.
        model_id = (custom_model_id or "").strip() or None

        # NOTE(review): the pipeline is constructed on every call; if
        # construction loads model weights, caching it may be worthwhile.
        pipeline = TKGDMPipeline(model_id=model_id, model_type=model_type, device=device)
        if pipeline.pipe is None:
            raise RuntimeError("Pipeline not available")

        return pipeline(
            prompt=prompt,
            channel_shifts=[ch0_shift, ch1_shift, ch2_shift, ch3_shift],
            bounding_boxes=bounding_boxes,
            # The intensity multiplier scales the base shift percent.
            target_shift_percent=shift_percent * intensity,
            # 0 means "auto-calculate", which is signalled with None.
            blur_sigma=None if blur_sigma == 0 else blur_sigma,
            # UI sliders may deliver floats; presumably the pipeline expects
            # an integer step count.
            num_inference_steps=int(steps),
            guidance_scale=7.5,
        )
    except Exception as e:
        # Deliberate best-effort boundary: any failure (missing model, no
        # GPU memory, ...) degrades to the demo rendering instead of an error.
        print(f"Using demo mode due to: {e}")
        return create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes)


def create_demo_visualization(prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift, bounding_boxes=None):
    """Create a 512x512 demo visualization of the TKG-DM concept.

    Args:
        prompt: Prompt text shown on the image (truncated to 40 characters;
            None is treated as empty).
        ch0_shift, ch1_shift, ch2_shift: Channel shifts mapped to an
            approximate RGB background colour.
        ch3_shift: Fourth channel shift; only shown in the caption.
        bounding_boxes: Iterable of normalized ``(x1, y1, x2, y2)`` tuples.
            A centre box is used when empty or None.

    Returns:
        An RGB ``PIL.Image.Image``.
    """
    size = 512

    def _to_channel(shift):
        """Map a shift value to a 0-255 colour component around mid-gray."""
        return max(0, min(255, 128 + int(shift * 127)))

    # Rough visual stand-in only: the first three channel shifts are shown
    # as red, green and blue respectively.
    approx_color = (_to_channel(ch0_shift), _to_channel(ch1_shift), _to_channel(ch2_shift))
    img = Image.new('RGB', (size, size), approx_color)
    draw = ImageDraw.Draw(img)

    # Default to a centre box if none is specified.
    if not bounding_boxes:
        bounding_boxes = [(0.3, 0.3, 0.7, 0.7)]

    for i, (x1, y1, x2, y2) in enumerate(bounding_boxes):
        # Sort each axis so swapped corners still give a valid rectangle.
        px1, px2 = sorted((int(x1 * size), int(x2 * size)))
        py1, py2 = sorted((int(y1 * size), int(y2 * size)))

        draw.rectangle([px1, py1, px2, py2], outline='yellow', width=3)
        # The inset outline needs at least 4 pixels of room; otherwise its
        # corners cross over, which newer Pillow versions reject.
        if px2 - px1 >= 4 and py2 - py1 >= 4:
            draw.rectangle([px1+2, py1+2, px2-2, py2-2], outline='orange', width=2)

        draw.text((px1+5, py1+5), f"Box {i+1}", fill='white')

    # Only add an ellipsis when the prompt was actually truncated.
    prompt_text = prompt or ""
    if len(prompt_text) > 40:
        prompt_text = prompt_text[:40] + "..."

    draw.text((10, 10), "TKG-DM Demo", fill='white')
    draw.text((10, 30), f"Prompt: {prompt_text}", fill='white')
    draw.text((10, 480), f"Channels: [{ch0_shift:+.2f},{ch1_shift:+.2f},{ch2_shift:+.2f},{ch3_shift:+.2f}]", fill='white')

    return img


# Create intuitive interface with step-by-step workflow
with gr.Blocks(title="🎨 SAWNA: Space-Aware Text-to-Image Generation", theme=gr.themes.Default()) as demo:

    # Coordinates shown on first load. Both the coordinate textbox and the
    # preview image are seeded from this value so they start out in agreement.
    initial_boxes_str = "0.3,0.3,0.7,0.7"

    # Header section with workflow explanation
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("""
            # 🎨 SAWNA: Space-Aware Text-to-Image Generation
            
            Create professional images with **guaranteed empty spaces** for headlines, logos, and product shots.
            Perfect for advertisements, posters, and UI mockups.
            """)
        with gr.Column(scale=2):
            gr.Markdown("""
            ### 🚀 Quick Start:
            1. **Describe** your image in the text prompt
            2. **Choose** where to keep empty (preset or custom)  
            3. **Adjust** colors and style (optional)
            4. **Generate** with guaranteed reserved regions
            """)

    with gr.Row():
        gr.Markdown("""
        ---
        💡 **How it works**: SAWNA uses advanced noise manipulation to suppress content generation in your specified regions, 
        ensuring they remain empty for your design elements while maintaining high quality in other areas.
        """)

    gr.Markdown("## 🎯 Create Your Space-Aware Image")

    # Main workflow section
    with gr.Row():
        # Left column - Input and controls
        with gr.Column(scale=2):

            # Step 1: Text Prompt
            with gr.Group():
                gr.Markdown("## 📝 Step 1: Describe Your Image")
                prompt = gr.Textbox(
                    value="A majestic lion in a natural landscape",
                    label="Text Prompt",
                    placeholder="Describe what you want to generate...",
                    lines=2
                )

            # Step 2: Reserved Regions
            with gr.Group():
                gr.Markdown("## 🔲 Step 2: Define Empty Regions")
                gr.Markdown("*Choose areas that must stay empty for your design elements*")

                # Quick presets: (label shown to the user, preset key)
                with gr.Row():
                    preset_dropdown = gr.Dropdown(
                        choices=[
                            ("None (Default Center)", "center_box"),
                            ("Top Banner", "top_strip"),
                            ("Bottom Banner", "bottom_strip"),
                            ("Side Panels", "left_right"),
                            ("Corner Logos", "corners"),
                            ("Full Frame", "frame")
                        ],
                        label="🚀 Quick Presets",
                        value="center_box"
                    )

                # Manual box creation
                gr.Markdown("**Or Create Custom Boxes:**")
                with gr.Row():
                    with gr.Column(scale=1):
                        x1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Left (X1)")
                        x2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Right (X2)")
                    with gr.Column(scale=1):
                        y1_input = gr.Number(value=0.3, minimum=0.0, maximum=1.0, step=0.01, label="Top (Y1)")
                        y2_input = gr.Number(value=0.7, minimum=0.0, maximum=1.0, step=0.01, label="Bottom (Y2)")

                with gr.Row():
                    add_box_btn = gr.Button("➕ Add Region", variant="primary", size="sm")
                    remove_box_btn = gr.Button("❌ Remove Last", variant="secondary", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")

                # Text representation
                bounding_boxes_str = gr.Textbox(
                    value=initial_boxes_str,
                    label="📋 Region Coordinates",
                    placeholder="x1,y1,x2,y2;x1,y1,x2,y2 (auto-updated)",
                    lines=2,
                    info="Coordinates are normalized (0.0 = left/top, 1.0 = right/bottom)"
                )

            # Step 3: Color and Style Controls
            with gr.Group():
                gr.Markdown("## 🎨 Step 3: Fine-tune Colors")
                gr.Markdown("*Adjust the 4 latent channels to control image colors and style*")

                with gr.Row():
                    ch0_shift = gr.Slider(-1.0, 1.0, 0.0, label="💡 Brightness", info="Overall image brightness")
                    ch1_shift = gr.Slider(-1.0, 1.0, 1.0, label="🔵 Blue-Red Balance", info="Shift toward blue (+) or red (-)")

                with gr.Row():
                    ch2_shift = gr.Slider(-1.0, 1.0, 1.0, label="🟡 Yellow-Blue Balance", info="Shift toward yellow (+) or dark blue (-)")
                    ch3_shift = gr.Slider(-1.0, 1.0, 0.0, label="⚪ Contrast", info="Adjust overall contrast")

        # Right column - Preview and results
        with gr.Column(scale=1):

            # Preview section
            with gr.Group():
                gr.Markdown("## 👁️ Preview: Empty Regions")
                bbox_preview = gr.Image(
                    # Render the default box right away instead of a blank
                    # canvas, matching the pre-filled coordinate textbox.
                    value=sync_text_to_canvas(initial_boxes_str),
                    label="Reserved Regions Visualization",
                    interactive=False,
                    type="pil"
                )
                gr.Markdown("*Red/yellow boxes show where content will be suppressed*")

    # Advanced controls (collapsible)
    with gr.Accordion("🎛️ Advanced Generation Settings", open=True):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Generation Settings")
                intensity = gr.Slider(0.5, 3.0, 1.0, label="Effect Intensity", info="How strongly to suppress content in empty regions")
                steps = gr.Slider(10, 100, 25, label="Quality Steps", info="More steps = higher quality, slower generation")

                gr.Markdown("### TKG-DM Technical Controls")
                shift_percent = gr.Slider(0.01, 0.15, 0.07, step=0.005, label="🎯 Shift Percent", info="Base shift percentage for noise optimization (±7% default)")
                blur_sigma = gr.Slider(0.0, 5.0, 0.0, step=0.1, label="🌫️ Blur Sigma", info="Gaussian blur for soft transitions (0 = auto)")

            with gr.Column():
                gr.Markdown("### Model Selection")
                model_type = gr.Dropdown(
                    ["sd1.5", "sdxl", "sd2.1"],
                    value="sd1.5",
                    label="Model Architecture",
                    info="SDXL for highest quality, SD1.5 for speed"
                )
                custom_model_id = gr.Textbox(
                    "",
                    label="Custom Model (Optional)",
                    placeholder="e.g., dreamlike-art/dreamlike-diffusion-1.0",
                    info="Use any Hugging Face Stable Diffusion model"
                )

    # Generation section
    with gr.Row():
        with gr.Column(scale=1):
            generate_btn = gr.Button(
                "🎨 Generate Space-Aware Image",
                variant="primary",
                size="lg",
                elem_id="generate-btn"
            )
            gr.Markdown("*Click to create your image with guaranteed empty regions*")

        with gr.Column(scale=3):
            output_image = gr.Image(
                label="✨ Generated Image",
                type="pil",
                height=500,
                elem_id="output-image"
            )

    # Examples section
    with gr.Accordion("📚 Example Prompts & Layouts", open=False):
        gr.Markdown("""
        ### Try these professional design scenarios:
        Click any example to load it automatically and see how SAWNA handles different layout requirements.
        """)

        # Each example row follows the order of `inputs` below.
        gr.Examples(
            examples=[
                [
                    "A majestic lion in African savanna",
                    0.2, 0.3, 0.0, 0.0, 1.0, 25, 0.07, 0.0, "sd1.5", "",
                    "0.3,0.3,0.7,0.7"
                ],
                [
                    "Modern cityscape with skyscrapers at sunset",
                    -0.1, -0.3, 0.2, 0.1, 1.2, 30, 0.08, 0.0, "sdxl", "",
                    "0.0,0.0,1.0,0.3"
                ],
                [
                    "Vintage luxury car on mountain road",
                    0.1, 0.2, -0.1, -0.2, 0.9, 25, 0.06, 0.0, "sd1.5", "",
                    "0.0,0.7,1.0,1.0"
                ],
                [
                    "Space astronaut floating in nebula",
                    0.0, 0.4, -0.2, 0.3, 1.1, 35, 0.09, 1.8, "sd2.1", "",
                    "0.0,0.2,0.3,0.8;0.7,0.2,1.0,0.8"
                ],
                [
                    "Product photography: premium watch (fine-tuned)",
                    0.2, 0.0, 0.1, -0.1, 1.3, 40, 0.12, 2.5, "sdxl", "",
                    "0.0,0.0,1.0,0.2;0.0,0.8,1.0,1.0;0.0,0.2,0.2,0.8;0.8,0.2,1.0,0.8"
                ]
            ],
            inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                    intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
            label="Professional Use Cases"
        )

    # Add custom CSS for better styling (injected as a <style> tag on page load)
    demo.load(fn=None, js="""
    function() {
        // Add custom styling
        const style = document.createElement('style');
        style.textContent = `
            .gradio-container {
                max-width: 1400px !important;
                margin: auto;
            }
            
            #generate-btn {
                background: linear-gradient(45deg, #7c3aed, #a855f7) !important;
                border: none !important;
                font-weight: bold !important;
                padding: 15px 30px !important;
                font-size: 16px !important;
            }
            
            #output-image {
                border-radius: 12px !important;
                box-shadow: 0 8px 32px rgba(0,0,0,0.1) !important;
            }
            
            .gr-group {
                border-radius: 12px !important;
                border: 1px solid #e5e7eb !important;
                padding: 20px !important;
                margin-bottom: 20px !important;
            }
            
            .gr-accordion {
                border-radius: 8px !important;
                border: 1px solid #d1d5db !important;
            }
        `;
        document.head.appendChild(style);
        return [];
    }
    """)

    # Event handlers
    def generate_wrapper(*args):
        """Forward the UI inputs, in order, to generate_tkg_dm_image."""
        return generate_tkg_dm_image(*args)

    def clear_boxes_handler():
        """Clear boxes and update preview: returns (empty string, blank canvas)."""
        return "", create_canvas_image()

    def update_preview_from_text(bbox_str):
        """Update preview image from text input."""
        return sync_text_to_canvas(bbox_str)

    def add_box_handler(bbox_str, x1, y1, x2, y2):
        """Add a new box and update preview: returns (string, preview image)."""
        return add_bounding_box(bbox_str, x1, y1, x2, y2)

    def remove_box_handler(bbox_str):
        """Remove last box and update preview: returns (string, preview image)."""
        return remove_last_box(bbox_str)

    def load_preset_handler(preset_name):
        """Load preset boxes and update preview.

        An empty selection or an unknown preset name yields an empty string
        and a blank canvas.
        """
        preset_str = load_preset_boxes(preset_name) if preset_name else ""
        return preset_str, sync_text_to_canvas(preset_str)

    # Preset dropdown
    preset_dropdown.change(
        fn=load_preset_handler,
        inputs=[preset_dropdown],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Add box button
    add_box_btn.click(
        fn=add_box_handler,
        inputs=[bounding_boxes_str, x1_input, y1_input, x2_input, y2_input],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Remove last box button
    remove_box_btn.click(
        fn=remove_box_handler,
        inputs=[bounding_boxes_str],
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Clear all boxes button
    clear_btn.click(
        fn=clear_boxes_handler,
        outputs=[bounding_boxes_str, bbox_preview]
    )

    # Sync text to preview canvas
    bounding_boxes_str.change(
        fn=update_preview_from_text,
        inputs=[bounding_boxes_str],
        outputs=[bbox_preview]
    )

    # Generate button
    generate_btn.click(
        fn=generate_wrapper,
        inputs=[prompt, ch0_shift, ch1_shift, ch2_shift, ch3_shift,
                intensity, steps, shift_percent, blur_sigma, model_type, custom_model_id, bounding_boxes_str],
        outputs=[output_image]
    )


if __name__ == "__main__":
    # Start the web UI only when this file is run as a script, not on import.
    launch_options = {"share": True}
    demo.launch(**launch_options)