aiqtech committed on
Commit cd76271 · verified · 1 Parent(s): b22198b

Update app.py

Files changed (1)
  1. app.py +343 -212
app.py CHANGED
@@ -1,261 +1,335 @@
  import gradio as gr
  import numpy as np
- import spaces
  import torch
  import random
- from PIL import Image, ImageOps
- from diffusers import FluxKontextPipeline
- from diffusers.utils import load_image

- # Load Kontext model with Reference Pose LoRA
  MAX_SEED = np.iinfo(np.int32).max

- # Initialize the pipeline
- pipe = FluxKontextPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-Kontext-dev",
-     torch_dtype=torch.bfloat16
- ).to("cuda")

- # Load the Reference Pose LoRA (if available)
- # Note: You'll need to add the actual LoRA loading code here
- # pipe.load_lora_weights("path/to/refcontrol-pose-lora", adapter_name="refcontrol")

- def prepare_pose_reference_pair(reference_image, pose_image):
      """
-     Prepare the reference image and pose control map for Kontext processing.
-
-     Args:
-         reference_image: PIL Image - The source image with identity/style to preserve
-         pose_image: PIL Image - The pose/line art control map
-
-     Returns:
-         PIL Image: Concatenated image with reference on left, pose on right
      """
      if reference_image is None or pose_image is None:
          return None

-     # Convert images to RGB
      reference_image = reference_image.convert("RGB")
      pose_image = pose_image.convert("RGB")

-     # Resize images to have the same height for better concatenation
-     target_height = 768  # Standard height for Flux
-
-     # Calculate proportional widths
      ref_ratio = reference_image.width / reference_image.height
      pose_ratio = pose_image.width / pose_image.height

-     ref_width = int(target_height * ref_ratio)
-     pose_width = int(target_height * pose_ratio)

-     # Ensure dimensions are divisible by 8 (required for Flux)
      ref_width = (ref_width // 8) * 8
      pose_width = (pose_width // 8) * 8

      # Resize images
-     reference_resized = reference_image.resize((ref_width, target_height), Image.LANCZOS)
-     pose_resized = pose_image.resize((pose_width, target_height), Image.LANCZOS)

-     # Create concatenated image: reference on left, pose on right
      total_width = ref_width + pose_width
-     concatenated = Image.new('RGB', (total_width, target_height), (255, 255, 255))
-
-     # Paste images
      concatenated.paste(reference_resized, (0, 0))
      concatenated.paste(pose_resized, (ref_width, 0))

-     return concatenated, ref_width, pose_width

- def process_pose_image(pose_image):
      """
-     Process the pose image to enhance line art visibility if needed.
      """
-     if pose_image is None:
          return None

-     pose_image = pose_image.convert("RGB")

-     # Optional: Enhance contrast for better pose detection
-     # You can add image processing here if the pose needs enhancement

-     return pose_image

- @spaces.GPU
- def infer_pose_transfer(
-     reference_image,
-     pose_image,
-     prompt="",
-     seed=42,
-     randomize_seed=False,
      guidance_scale=3.5,
-     strength=0.85,
      progress=gr.Progress(track_tqdm=True)
  ):
      """
-     Transfer pose from control image to reference image using Flux Kontext.
      """

      if reference_image is None or pose_image is None:
-         raise gr.Error("Please upload both a reference image and a pose image.")

      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

-     # Process pose image if needed
-     pose_image = process_pose_image(pose_image)

-     # Prepare the concatenated input
-     concatenated_input, ref_width, pose_width = prepare_pose_reference_pair(
-         reference_image,
-         pose_image
-     )

      if concatenated_input is None:
-         raise gr.Error("Failed to process the input images.")
-
-     # Construct the prompt with the trigger word
-     base_prompt = "refcontrolpose"

      if prompt:
-         # User-provided prompt with trigger word
-         full_prompt = f"{base_prompt}, {prompt}"
      else:
-         # Default prompt for pose transfer
-         full_prompt = f"{base_prompt}, transfer the pose from the right image to the subject in the left image, maintaining the identity, clothing, and style of the original subject while adopting the exact pose and body position shown in the control map"
-
-     # Add instruction for the model to understand the layout
-     full_prompt += ". The left side shows the reference with identity to preserve, the right side shows the target pose to follow."
-
-     # Generate the image
-     with torch.autocast("cuda"):
-         result = pipe(
-             image=concatenated_input,
-             prompt=full_prompt,
-             guidance_scale=guidance_scale,
-             num_inference_steps=28,
-             width=concatenated_input.size[0],
-             height=concatenated_input.size[1],
-             generator=torch.Generator("cuda").manual_seed(seed),
-         ).images[0]

-     # Optional: Crop the result to show only the transformed subject
-     # You might want to crop out the concatenated input and show only the result

-     return result, seed, concatenated_input
-
- def create_pose_from_image(image):
-     """
-     Helper function to extract pose/line art from an image.
-     This is a placeholder - you might want to integrate with OpenPose or similar.
-     """
-     if image is None:
-         return None

-     # Placeholder: In production, you'd use OpenPose or similar
-     # For now, we'll just convert to grayscale as a simple edge detection
-     from PIL import ImageFilter, ImageOps
-
-     image = image.convert("L")  # Convert to grayscale
-     image = image.filter(ImageFilter.FIND_EDGES)  # Simple edge detection
-     image = ImageOps.invert(image)  # Invert to get black lines on white
-     image = image.convert("RGB")  # Convert back to RGB
-
-     return image

  # CSS styling
  css = """
  #col-container {
      margin: 0 auto;
-     max-width: 1200px;
  }
- .image-container {
      border: 2px solid #e0e0e0;
      border-radius: 8px;
-     padding: 10px;
-     background: #f9f9f9;
  }
- .result-container {
      border: 3px solid #4CAF50;
      border-radius: 8px;
-     padding: 10px;
-     background: #f0f8f0;
  }
  """

  # Create Gradio interface
- with gr.Blocks(css=css) as demo:

      with gr.Column(elem_id="col-container"):
-         gr.Markdown("""
-         # 🎭 FLUX.1 Kontext Reference Pose Transfer
-
-         Transfer any pose to your subject while preserving their identity and style!
-
-         **How it works:**
-         1. Upload a **reference image** (your subject with identity/style to preserve)
-         2. Upload a **pose image** (line art or pose skeleton to follow)
-         3. The model will generate your subject in the new pose

-         Uses the **refcontrolpose** LoRA for precise pose control.
-         """)

          with gr.Row():
              with gr.Column(scale=1):
-                 gr.Markdown("### 📸 Input Images")

-                 with gr.Row():
-                     with gr.Column():
-                         reference_image = gr.Image(
-                             label="Reference Image (Subject)",
-                             type="pil",
-                             elem_classes=["image-container"],
-                             height=300
-                         )
-                         gr.Markdown("*Upload the image with the subject/style to preserve*")
-
-                     with gr.Column():
-                         pose_image = gr.Image(
-                             label="Pose Control (Line Art)",
-                             type="pil",
-                             elem_classes=["image-container"],
-                             height=300
-                         )
-                         gr.Markdown("*Upload the pose/line art to follow*")

-                 # Optional: Add pose extraction tool
                  with gr.Accordion("🔧 Extract Pose from Image", open=False):
-                     source_for_pose = gr.Image(
-                         label="Source Image for Pose Extraction",
                          type="pil",
                          height=200
                      )
-                     extract_pose_btn = gr.Button("Extract Pose", size="sm")

                  prompt = gr.Textbox(
-                     label="Additional Prompt (Optional)",
-                     placeholder="e.g., wearing a red dress, in a garden, professional photography",
-                     info="Add details about the desired output (trigger word 'refcontrolpose' is added automatically)",
                      lines=2
                  )

-                 with gr.Row():
-                     run_button = gr.Button("🎨 Transfer Pose", variant="primary", scale=2)
-                     clear_button = gr.Button("🗑️ Clear", scale=1)

                  with gr.Accordion("⚙️ Advanced Settings", open=False):
-
-                     seed = gr.Slider(
-                         label="Seed",
-                         minimum=0,
-                         maximum=MAX_SEED,
-                         step=1,
-                         value=42,
-                     )
-
-                     randomize_seed = gr.Checkbox(
-                         label="Randomize seed",
-                         value=True
-                     )

                      guidance_scale = gr.Slider(
                          label="Guidance Scale",
@@ -263,84 +337,141 @@ with gr.Blocks(css=css) as demo:
                          maximum=10.0,
                          step=0.5,
                          value=3.5,
-                         info="Higher values follow the pose more strictly"
                      )

-                     strength = gr.Slider(
-                         label="Transformation Strength",
-                         minimum=0.1,
-                         maximum=1.0,
-                         step=0.05,
-                         value=0.85,
-                         info="How much to change from the original"
                      )

              with gr.Column(scale=1):
-                 gr.Markdown("### 🖼️ Results")

-                 result = gr.Image(
-                     label="Generated Result",
-                     elem_classes=["result-container"],
                      interactive=False,
-                     height=400
                  )

-                 with gr.Accordion("📊 Generation Info", open=False):
-                     used_seed = gr.Number(label="Seed Used", interactive=False)
-                     input_preview = gr.Image(
-                         label="Concatenated Input (Reference | Pose)",
                          height=200
                      )

                  with gr.Row():
-                     save_button = gr.Button("💾 Save Result", size="sm")
-                     reuse_button = gr.Button("♻️ Use as Reference", size="sm")

      # Examples
-     gr.Markdown("### 💡 Examples")
      gr.Examples(
          examples=[
-             ["A person in business attire", "standing confidently"],
-             ["A dancer in elegant costume", "performing a ballet leap"],
-             ["An athlete in sportswear", "doing a martial arts kick"],
-             ["A model in casual outfit", "sitting on a chair"],
          ],
-         inputs=[prompt],
-         label="Example Prompts"
      )

      # Event handlers
-     run_button.click(
-         fn=infer_pose_transfer,
          inputs=[
-             reference_image,
-             pose_image,
-             prompt,
-             seed,
-             randomize_seed,
              guidance_scale,
-             strength
          ],
-         outputs=[result, used_seed, input_preview]
      )

-     extract_pose_btn.click(
-         fn=create_pose_from_image,
-         inputs=[source_for_pose],
          outputs=[pose_image]
      )

-     reuse_button.click(
-         fn=lambda img: img,
-         inputs=[result],
          outputs=[reference_image]
      )

-     clear_button.click(
-         fn=lambda: [None, None, "", None, 42, None],
-         outputs=[reference_image, pose_image, prompt, result, used_seed, input_preview]
      )

  # Launch the app
- demo.queue()
- demo.launch()
 
  import gradio as gr
  import numpy as np
  import torch
  import random
+ import os
+ import spaces
+ from PIL import Image, ImageOps, ImageFilter
+ from diffusers import FluxPipeline, DiffusionPipeline
+ from diffusers.loaders import LoraLoaderMixin
+ import requests
+ from io import BytesIO

+ # Constants
  MAX_SEED = np.iinfo(np.int32).max
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Model configuration
+ KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
+ LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
+ TRIGGER_WORD = "refcontrolpose"

+ # Initialize pipeline with authentication
+ print("Loading models...")

+ try:
+     # Load Flux Kontext pipeline with HF token
+     if HF_TOKEN:
+         from diffusers import FluxKontextPipeline
+         pipe = FluxKontextPipeline.from_pretrained(
+             KONTEXT_MODEL,
+             torch_dtype=torch.bfloat16,
+             use_auth_token=HF_TOKEN
+         )
+
+         # Load the RefControl LoRA
+         pipe.load_lora_weights(
+             LORA_MODEL,
+             adapter_name="refcontrol",
+             use_auth_token=HF_TOKEN
+         )
+
+         # Move to GPU
+         pipe = pipe.to("cuda")
+
+         MODEL_STATUS = "✅ Flux Kontext + RefControl LoRA loaded successfully"
+         print(MODEL_STATUS)
+
+     else:
+         raise ValueError("HF_TOKEN not found in environment variables")
+
+ except Exception as e:
+     print(f"Error loading models: {e}")
+     # Fallback to base model without LoRA
+     try:
+         pipe = DiffusionPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-dev",
+             torch_dtype=torch.bfloat16,
+             use_auth_token=HF_TOKEN if HF_TOKEN else True
+         ).to("cuda")
+         MODEL_STATUS = "⚠️ Running in fallback mode (FLUX.1-dev without LoRA)"
+     except:
+         MODEL_STATUS = "❌ Failed to load models. Please check HF_TOKEN"
+         pipe = None

+ def prepare_images_for_kontext(reference_image, pose_image, target_size=768):
      """
+     Prepare reference and pose images for Kontext processing.
+     Following the RefControl format: reference (left) | pose (right)
      """
      if reference_image is None or pose_image is None:
          return None

+     # Convert to RGB
      reference_image = reference_image.convert("RGB")
      pose_image = pose_image.convert("RGB")

+     # Calculate dimensions maintaining aspect ratio
      ref_ratio = reference_image.width / reference_image.height
      pose_ratio = pose_image.width / pose_image.height

+     # Set heights to target size
+     height = target_size
+     ref_width = int(height * ref_ratio)
+     pose_width = int(height * pose_ratio)

+     # Ensure dimensions are divisible by 8 (FLUX requirement)
      ref_width = (ref_width // 8) * 8
      pose_width = (pose_width // 8) * 8
+     height = (height // 8) * 8

      # Resize images
+     reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS)
+     pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS)

+     # Concatenate horizontally: reference | pose
      total_width = ref_width + pose_width
+     concatenated = Image.new('RGB', (total_width, height))
      concatenated.paste(reference_resized, (0, 0))
      concatenated.paste(pose_resized, (ref_width, 0))

+     return concatenated

+ def extract_pose_edges(image):
      """
+     Extract edge/pose information from an image.
      """
+     if image is None:
          return None

+     # Convert to grayscale
+     gray = image.convert("L")
+
+     # Apply edge detection
+     edges = gray.filter(ImageFilter.FIND_EDGES)
+
+     # Enhance contrast
+     edges = ImageOps.autocontrast(edges)
+
+     # Invert to get black lines on white
+     edges = ImageOps.invert(edges)

+     # Smooth the result
+     edges = edges.filter(ImageFilter.SMOOTH_MORE)

+     # Convert back to RGB
+     return edges.convert("RGB")

+ @spaces.GPU(duration=60)
+ def generate_pose_transfer(
+     reference_image,
+     pose_image,
+     prompt="",
+     negative_prompt="",
+     seed=42,
+     randomize_seed=False,
      guidance_scale=3.5,
+     num_inference_steps=28,
+     lora_scale=1.0,
+     enhance_pose=False,
      progress=gr.Progress(track_tqdm=True)
  ):
      """
+     Main generation function using RefControl LoRA.
      """

+     if pipe is None:
+         return None, 0, "Model not loaded. Please check HF_TOKEN"
+
      if reference_image is None or pose_image is None:
+         raise gr.Error("Please upload both reference and pose images")

+     # Randomize seed if requested
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

+     # Enhance pose if requested
+     if enhance_pose:
+         pose_image = extract_pose_edges(pose_image)

+     # Prepare concatenated input
+     concatenated_input = prepare_images_for_kontext(reference_image, pose_image)

      if concatenated_input is None:
+         raise gr.Error("Failed to process images")

+     # Construct prompt with trigger word
      if prompt:
+         full_prompt = f"{TRIGGER_WORD}, {prompt}"
      else:
+         full_prompt = f"{TRIGGER_WORD}, transfer the pose from the right image to the subject in the left image while maintaining their identity, clothing, and style"

+     # Add instruction for the model
+     full_prompt += ". The left image shows the reference subject, the right image shows the target pose."

+     # Set generator for reproducibility
+     generator = torch.Generator("cuda").manual_seed(seed)

+     try:
+         # Generate with LoRA
+         with torch.autocast("cuda"):
+             if hasattr(pipe, 'set_adapters'):
+                 # Set LoRA adapter strength
+                 pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
+
+             # Generate image
+             result = pipe(
+                 image=concatenated_input,
+                 prompt=full_prompt,
+                 negative_prompt=negative_prompt,
+                 guidance_scale=guidance_scale,
+                 num_inference_steps=num_inference_steps,
+                 generator=generator,
+                 width=concatenated_input.width,
+                 height=concatenated_input.height,
+             ).images[0]
+
+         return result, seed, concatenated_input
+
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")

  # CSS styling
  css = """
  #col-container {
      margin: 0 auto;
+     max-width: 1280px;
+ }
+ .header {
+     text-align: center;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     padding: 20px;
+     border-radius: 12px;
+     margin-bottom: 20px;
+ }
+ .header h1 {
+     color: white;
+     margin: 0;
+ }
+ .status-box {
+     padding: 10px;
+     border-radius: 8px;
+     margin: 10px 0;
+     font-weight: bold;
  }
+ .input-image {
      border: 2px solid #e0e0e0;
      border-radius: 8px;
+     overflow: hidden;
  }
+ .result-image {
      border: 3px solid #4CAF50;
      border-radius: 8px;
+     overflow: hidden;
  }
  """

  # Create Gradio interface
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:

      with gr.Column(elem_id="col-container"):
+         # Header with authentication
+         with gr.Row():
+             with gr.Column():
+                 gr.HTML("""
+                 <div class="header">
+                     <h1>🎭 RefControl Flux Kontext - Reference Pose Transfer</h1>
+                     <p style="color: white;">Powered by thedeoxen/refcontrol-flux-kontext-reference-pose-lora</p>
+                 </div>
+                 """)
+
+                 # Model status
+                 gr.Markdown(f"""
+                 <div class="status-box" style="background: {'#d4edda' if '✅' in MODEL_STATUS else '#f8d7da'};">
+                 {MODEL_STATUS}
+                 </div>
+                 """)

+         # Authentication info
+         if not HF_TOKEN:
+             gr.Markdown("""
+             ### 🔐 Authentication Required
+             Please set your Hugging Face token to use this Space:
+             1. Go to Settings → Variables and secrets
+             2. Add `HF_TOKEN` with your Hugging Face token
+             3. Restart the Space
+             """)
+             gr.LoginButton("Sign in with Hugging Face", size="lg")

+         # Main interface
          with gr.Row():
              with gr.Column(scale=1):
+                 gr.Markdown("### 📥 Input Images")

+                 # Reference image
+                 reference_image = gr.Image(
+                     label="Reference Image (Subject to transform)",
+                     type="pil",
+                     elem_classes=["input-image"],
+                     height=300
+                 )

+                 # Pose image
+                 pose_image = gr.Image(
+                     label="Pose Control (Line art or skeleton)",
+                     type="pil",
+                     elem_classes=["input-image"],
+                     height=300
+                 )
+
+                 # Pose extraction tool
                  with gr.Accordion("🔧 Extract Pose from Image", open=False):
+                     extract_source = gr.Image(
+                         label="Source image for pose extraction",
                          type="pil",
                          height=200
                      )
+                     extract_btn = gr.Button("Extract Pose", size="sm")

+                 # Prompts
                  prompt = gr.Textbox(
+                     label=f"Prompt (trigger word '{TRIGGER_WORD}' added automatically)",
+                     placeholder="e.g., wearing elegant dress, professional photography",
                      lines=2
                  )

+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     placeholder="e.g., blurry, low quality, distorted",
+                     lines=2
+                 )

+                 # Generate button
+                 generate_btn = gr.Button(
+                     "🎨 Generate Pose Transfer",
+                     variant="primary",
+                     size="lg"
+                 )
+
+                 # Advanced settings
                  with gr.Accordion("⚙️ Advanced Settings", open=False):
+                     with gr.Row():
+                         seed = gr.Slider(
+                             label="Seed",
+                             minimum=0,
+                             maximum=MAX_SEED,
+                             step=1,
+                             value=42
+                         )
+                         randomize_seed = gr.Checkbox(
+                             label="Randomize",
+                             value=True
+                         )

                      guidance_scale = gr.Slider(
                          label="Guidance Scale",
                          maximum=10.0,
                          step=0.5,
                          value=3.5,
+                         info="How strictly to follow the pose"
                      )

+                     num_inference_steps = gr.Slider(
+                         label="Inference Steps",
+                         minimum=20,
+                         maximum=50,
+                         step=1,
+                         value=28
+                     )
+
+                     lora_scale = gr.Slider(
+                         label="LoRA Strength",
+                         minimum=0.0,
+                         maximum=2.0,
+                         step=0.1,
+                         value=1.0,
+                         info="RefControl LoRA influence"
+                     )
+
+                     enhance_pose = gr.Checkbox(
+                         label="Auto-enhance pose edges",
+                         value=False
                      )

              with gr.Column(scale=1):
+                 gr.Markdown("### 🖼️ Generated Result")

+                 # Result image
+                 result_image = gr.Image(
+                     label="Generated Image",
+                     elem_classes=["result-image"],
                      interactive=False,
+                     height=500
                  )

+                 # Info display
+                 with gr.Row():
+                     seed_used = gr.Number(
+                         label="Seed Used",
+                         interactive=False
+                     )
+
+                 # Debug view
+                 with gr.Accordion("🔍 Debug View", open=False):
+                     concat_preview = gr.Image(
+                         label="Input Concatenation (Reference | Pose)",
                          height=200
                      )

+                 # Reuse buttons
                  with gr.Row():
+                     reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm")
+                     reuse_pose_btn = gr.Button("📐 Extract & Use as Pose", size="sm")
+                     clear_btn = gr.Button("🗑️ Clear All", size="sm")

      # Examples
+     gr.Markdown("### 💡 Example Prompts")
      gr.Examples(
          examples=[
+             ["professional portrait, studio lighting, high quality"],
+             ["wearing red dress, outdoor garden setting"],
+             ["business attire, corporate headshot"],
+             ["casual streetwear, urban background"],
+             ["athletic wear, dynamic action shot"],
+             ["elegant evening gown, luxury setting"],
          ],
+         inputs=[prompt]
      )
+
+     # Instructions
+     with gr.Accordion("📖 How to Use", open=False):
+         gr.Markdown("""
+         1. **Upload Reference Image**: The person/subject you want to transform
+         2. **Upload Pose Image**: Line art or skeleton pose to follow
+         3. **Optional**: Add descriptive prompt for style/setting
+         4. **Click Generate**: Wait for the magic to happen!
+
+         **Tips:**
+         - Use clear, high-contrast pose images for best results
+         - The model preserves identity from reference while following pose
+         - Adjust LoRA strength to balance identity vs pose adherence
+         - Higher guidance scale = stricter pose following
+         """)

      # Event handlers
+     generate_btn.click(
+         fn=generate_pose_transfer,
          inputs=[
+             reference_image,
+             pose_image,
+             prompt,
+             negative_prompt,
+             seed,
+             randomize_seed,
              guidance_scale,
+             num_inference_steps,
+             lora_scale,
+             enhance_pose
          ],
+         outputs=[result_image, seed_used, concat_preview]
      )

+     extract_btn.click(
+         fn=extract_pose_edges,
+         inputs=[extract_source],
          outputs=[pose_image]
      )

+     reuse_ref_btn.click(
+         fn=lambda x: x,
+         inputs=[result_image],
          outputs=[reference_image]
      )

+     reuse_pose_btn.click(
+         fn=extract_pose_edges,
+         inputs=[result_image],
+         outputs=[pose_image]
+     )
+
+     clear_btn.click(
+         fn=lambda: [None, None, "", "", 42, None, None],
+         outputs=[
+             reference_image,
+             pose_image,
+             prompt,
+             negative_prompt,
+             seed_used,
+             result_image,
+             concat_preview
+         ]
      )

  # Launch the app
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch()
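
Standalone sketch of the reference | pose concatenation that the new prepare_images_for_kontext helper builds before calling the pipeline (PIL only, no model or GPU needed; the file names below are illustrative placeholders, not part of the commit):

from PIL import Image

def concat_reference_and_pose(reference_path, pose_path, target_height=768):
    # Load both inputs and force RGB, mirroring the Space's preprocessing.
    reference = Image.open(reference_path).convert("RGB")
    pose = Image.open(pose_path).convert("RGB")

    # Scale each image to the shared height while preserving aspect ratio.
    ref_width = int(target_height * reference.width / reference.height)
    pose_width = int(target_height * pose.width / pose.height)

    # FLUX expects dimensions divisible by 8.
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
    height = (target_height // 8) * 8

    reference = reference.resize((ref_width, height), Image.LANCZOS)
    pose = pose.resize((pose_width, height), Image.LANCZOS)

    # Reference goes on the left, pose control map on the right.
    canvas = Image.new("RGB", (ref_width + pose_width, height))
    canvas.paste(reference, (0, 0))
    canvas.paste(pose, (ref_width, 0))
    return canvas

if __name__ == "__main__":
    concat_reference_and_pose("reference.png", "pose.png").save("kontext_input.png")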