aiqtech committed on
Commit b22198b · verified · 1 Parent(s): 12007e9

Update app.py

Files changed (1)
  1. app.py +274 -134

app.py CHANGED
@@ -3,204 +3,344 @@ import numpy as np
  import spaces
  import torch
  import random
- from PIL import Image
- #from kontext_pipeline import FluxKontextPipeline
  from diffusers import FluxKontextPipeline
  from diffusers.utils import load_image

- # Load Kontext model
  MAX_SEED = np.iinfo(np.int32).max

- pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")

- def concatenate_images(images, direction="horizontal"):
      """
-     Concatenate multiple PIL images either horizontally or vertically.

      Args:
-         images: List of PIL Images
-         direction: "horizontal" or "vertical"

      Returns:
-         PIL Image: Concatenated image
      """
-     if not images:
          return None

-     # Filter out None images
-     valid_images = [img for img in images if img is not None]

-     if not valid_images:
          return None

-     if len(valid_images) == 1:
-         return valid_images[0].convert("RGB")

-     # Convert all images to RGB
-     valid_images = [img.convert("RGB") for img in valid_images]

-     if direction == "horizontal":
-         # Calculate total width and max height
-         total_width = sum(img.width for img in valid_images)
-         max_height = max(img.height for img in valid_images)
-
-         # Create new image
-         concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255))
-
-         # Paste images
-         x_offset = 0
-         for img in valid_images:
-             # Center image vertically if heights differ
-             y_offset = (max_height - img.height) // 2
-             concatenated.paste(img, (x_offset, y_offset))
-             x_offset += img.width
-
-     else: # vertical
-         # Calculate max width and total height
-         max_width = max(img.width for img in valid_images)
-         total_height = sum(img.height for img in valid_images)
-
-         # Create new image
-         concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255))
-
-         # Paste images
-         y_offset = 0
-         for img in valid_images:
-             # Center image horizontally if widths differ
-             x_offset = (max_width - img.width) // 2
-             concatenated.paste(img, (x_offset, y_offset))
-             y_offset += img.height
-
-     return concatenated

  @spaces.GPU
- def infer(input_images, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):

      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

-     # Handle input_images - it could be a single image or a list of images
-     if input_images is None:
-         raise gr.Error("Please upload at least one image.")

-     # If it's a single image (not a list), convert to list
-     if not isinstance(input_images, list):
-         input_images = [input_images]

-     # Filter out None images
-     valid_images = [img[0] for img in input_images if img is not None]

-     if not valid_images:
-         raise gr.Error("Please upload at least one valid image.")

-     # Concatenate images horizontally
-     concatenated_image = concatenate_images(valid_images, "horizontal")

-     if concatenated_image is None:
-         raise gr.Error("Failed to process the input images.")

-     # original_width, original_height = concatenated_image.size

-     # if original_width >= original_height:
-     #     new_width = 1024
-     #     new_height = int(original_height * (new_width / original_width))
-     #     new_height = round(new_height / 64) * 64
-     # else:
-     #     new_height = 1024
-     #     new_width = int(original_width * (new_height / original_height))
-     #     new_width = round(new_width / 64) * 64

-     #concatenated_image_resized = concatenated_image.resize((new_width, new_height), Image.LANCZOS)

-     final_prompt = f"From the provided reference images, create a unified, cohesive image such that {prompt}. Maintain the identity and characteristics of each subject while adjusting their proportions, scale, and positioning to create a harmonious, naturally balanced composition. Blend and integrate all elements seamlessly with consistent lighting, perspective, and style.the final result should look like a single naturally captured scene where all subjects are properly sized and positioned relative to each other, not assembled from multiple sources."
-
-     image = pipe(
-         image=concatenated_image,
-         prompt=final_prompt,
-         guidance_scale=guidance_scale,
-         width=concatenated_image.size[0],
-         height=concatenated_image.size[1],
-         generator=torch.Generator().manual_seed(seed),
-     ).images[0]
-
-     return image, seed, gr.update(visible=True)

- css="""
  #col-container {
      margin: 0 auto;
-     max-width: 960px;
  }
  """

  with gr.Blocks(css=css) as demo:

      with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""# FLUX.1 Kontext [dev] - Multi-Image
- Flux Kontext with multiple image input support - compose a new image with elements from multiple images using Kontext [dev]
          """)
          with gr.Row():
-             with gr.Column():
-                 input_images = gr.Gallery(
-                     label="Upload image(s) for editing",
-                     show_label=True,
-                     elem_id="gallery_input",
-                     columns=3,
-                     rows=2,
-                     object_fit="contain",
-                     height="auto",
-                     file_types=['image'],
-                     type='pil'
-                 )
-
                  with gr.Row():
-                     prompt = gr.Text(
-                         label="Prompt",
-                         show_label=False,
-                         info = "describe the desired output composition",
-                         max_lines=1,
-                         placeholder="e.g. the dog from the left image sits on the bench from the right image",
-                         container=False,
                      )
-                     run_button = gr.Button("Run", scale=0)

-                 with gr.Accordion("Advanced Settings", open=False):
-
                      seed = gr.Slider(
                          label="Seed",
                          minimum=0,
                          maximum=MAX_SEED,
                          step=1,
-                         value=0,
                      )

-                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                      guidance_scale = gr.Slider(
                          label="Guidance Scale",
-                         minimum=1,
-                         maximum=10,
-                         step=0.1,
-                         value=2.5,
-                     )

-             with gr.Column():
-                 result = gr.Image(label="Result", show_label=False, interactive=False)
-                 reuse_button = gr.Button("Reuse this image", visible=False)
-
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn = infer,
-         inputs = [input_images, prompt, seed, randomize_seed, guidance_scale],
-         outputs = [result, seed, reuse_button]
      )

      reuse_button.click(
-         fn = lambda image: [image] if image is not None else [], # Convert single image to list for gallery
-         inputs = [result],
-         outputs = [input_images]
      )

  demo.launch()
 
  import spaces
  import torch
  import random
+ from PIL import Image, ImageOps
  from diffusers import FluxKontextPipeline
  from diffusers.utils import load_image

+ # Load Kontext model with Reference Pose LoRA
  MAX_SEED = np.iinfo(np.int32).max

+ # Initialize the pipeline
+ pipe = FluxKontextPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-Kontext-dev",
+     torch_dtype=torch.bfloat16
+ ).to("cuda")

+ # Load the Reference Pose LoRA (if available)
+ # Note: You'll need to add the actual LoRA loading code here
+ # pipe.load_lora_weights("path/to/refcontrol-pose-lora", adapter_name="refcontrol")
+
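The LoRA load left as a comment above would, with diffusers' standard LoRA API, look roughly like the sketch below; the repository path is the placeholder from the comment, not a published model.

    # Sketch only -- "path/to/refcontrol-pose-lora" is a placeholder path.
    pipe.load_lora_weights(
        "path/to/refcontrol-pose-lora",
        adapter_name="refcontrol",
    )
    pipe.set_adapters(["refcontrol"], adapter_weights=[1.0])  # activate the adapter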
+ def prepare_pose_reference_pair(reference_image, pose_image):
      """
+     Prepare the reference image and pose control map for Kontext processing.

      Args:
+         reference_image: PIL Image - The source image with identity/style to preserve
+         pose_image: PIL Image - The pose/line art control map

      Returns:
+         tuple: (concatenated image with reference on left and pose on right, ref_width, pose_width)
      """
+     if reference_image is None or pose_image is None:
          return None

+     # Convert images to RGB
+     reference_image = reference_image.convert("RGB")
+     pose_image = pose_image.convert("RGB")
+
+     # Resize images to have the same height for better concatenation
+     target_height = 768 # Standard height for Flux
+
+     # Calculate proportional widths
+     ref_ratio = reference_image.width / reference_image.height
+     pose_ratio = pose_image.width / pose_image.height
+
+     ref_width = int(target_height * ref_ratio)
+     pose_width = int(target_height * pose_ratio)
+
+     # Ensure dimensions are divisible by 8 (required for Flux)
+     ref_width = (ref_width // 8) * 8
+     pose_width = (pose_width // 8) * 8
+
+     # Resize images
+     reference_resized = reference_image.resize((ref_width, target_height), Image.LANCZOS)
+     pose_resized = pose_image.resize((pose_width, target_height), Image.LANCZOS)
+
+     # Create concatenated image: reference on left, pose on right
+     total_width = ref_width + pose_width
+     concatenated = Image.new('RGB', (total_width, target_height), (255, 255, 255))
+
+     # Paste images
+     concatenated.paste(reference_resized, (0, 0))
+     concatenated.paste(pose_resized, (ref_width, 0))

+     return concatenated, ref_width, pose_width
+
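The sizing arithmetic above keeps each pane's aspect ratio at a fixed 768-pixel height and floors widths to multiples of 8. A quick illustrative check (the input dimensions are hypothetical, not part of the app):

    # A 1080x1920 portrait reference at target_height = 768:
    target_height = 768
    ref_width = int(target_height * (1080 / 1920))  # int(768 * 0.5625) = 432
    ref_width = (ref_width // 8) * 8                # 432 is already a multiple of 8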
+ def process_pose_image(pose_image):
+     """
+     Process the pose image to enhance line art visibility if needed.
+     """
+     if pose_image is None:
          return None

+     pose_image = pose_image.convert("RGB")

+     # Optional: Enhance contrast for better pose detection
+     # You can add image processing here if the pose needs enhancement

+     return pose_image
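The enhancement step is left as a comment above; one Pillow-only way to realize it, if faint line art needs strengthening, is a sketch like this (ImageOps is already imported at the top of the new file):

    # Sketch: stretch the histogram so faint lines read as near-black on white.
    pose_image = ImageOps.autocontrast(pose_image)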
 
  @spaces.GPU
+ def infer_pose_transfer(
+     reference_image,
+     pose_image,
+     prompt="",
+     seed=42,
+     randomize_seed=False,
+     guidance_scale=3.5,
+     strength=0.85,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     """
+     Transfer pose from control image to reference image using Flux Kontext.
+     """
+
+     if reference_image is None or pose_image is None:
+         raise gr.Error("Please upload both a reference image and a pose image.")

      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

+     # Process pose image if needed
+     pose_image = process_pose_image(pose_image)

+     # Prepare the concatenated input
+     concatenated_input, ref_width, pose_width = prepare_pose_reference_pair(
+         reference_image,
+         pose_image
+     )

+     if concatenated_input is None:
+         raise gr.Error("Failed to process the input images.")

+     # Construct the prompt with the trigger word
+     base_prompt = "refcontrolpose"

+     if prompt:
+         # User-provided prompt with trigger word
+         full_prompt = f"{base_prompt}, {prompt}"
+     else:
+         # Default prompt for pose transfer
+         full_prompt = f"{base_prompt}, transfer the pose from the right image to the subject in the left image, maintaining the identity, clothing, and style of the original subject while adopting the exact pose and body position shown in the control map"

+     # Add instruction for the model to understand the layout
+     full_prompt += ". The left side shows the reference with identity to preserve, the right side shows the target pose to follow."

+     # Generate the image
+     with torch.autocast("cuda"):
+         result = pipe(
+             image=concatenated_input,
+             prompt=full_prompt,
+             guidance_scale=guidance_scale,
+             num_inference_steps=28,
+             width=concatenated_input.size[0],
+             height=concatenated_input.size[1],
+             generator=torch.Generator("cuda").manual_seed(seed),
+         ).images[0]

+     # Optional: Crop the result to show only the transformed subject
+     # You might want to crop out the concatenated input and show only the result

+     return result, seed, concatenated_input
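Two details in the function above are worth flagging: the strength parameter is accepted but never passed to the pipeline call, and the crop suggested by the trailing comment is unimplemented. If the generated image preserved the side-by-side layout of the input (an untested assumption), the crop could be sketched as:

    # Hypothetical helper: keep only the left, reference-sized region.
    def crop_reference_half(img, ref_width):
        return img.crop((0, 0, ref_width, img.height))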
 
+ def create_pose_from_image(image):
+     """
+     Helper function to extract pose/line art from an image.
+     This is a placeholder - you might want to integrate with OpenPose or similar.
+     """
+     if image is None:
+         return None
+
+     # Placeholder: In production, you'd use OpenPose or similar
+     # For now, we'll just convert to grayscale as a simple edge detection
+     from PIL import ImageFilter, ImageOps
+
+     image = image.convert("L") # Convert to grayscale
+     image = image.filter(ImageFilter.FIND_EDGES) # Simple edge detection
+     image = ImageOps.invert(image) # Invert to get black lines on white
+     image = image.convert("RGB") # Convert back to RGB
+
+     return image
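For the "OpenPose or similar" integration the docstring mentions, a common choice is the controlnet_aux package; a sketch under that assumption (the package is not a dependency of this commit):

    # Assumed dependency: pip install controlnet_aux
    from controlnet_aux import OpenposeDetector

    detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

    def extract_openpose(image):
        # Returns a PIL image of the detected pose skeleton.
        return detector(image)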
 
+ # CSS styling
+ css = """
  #col-container {
      margin: 0 auto;
+     max-width: 1200px;
+ }
+ .image-container {
+     border: 2px solid #e0e0e0;
+     border-radius: 8px;
+     padding: 10px;
+     background: #f9f9f9;
+ }
+ .result-container {
+     border: 3px solid #4CAF50;
+     border-radius: 8px;
+     padding: 10px;
+     background: #f0f8f0;
  }
  """

+ # Create Gradio interface
  with gr.Blocks(css=css) as demo:

      with gr.Column(elem_id="col-container"):
+         gr.Markdown("""
+         # 🎭 FLUX.1 Kontext Reference Pose Transfer
+
+         Transfer any pose to your subject while preserving their identity and style!
+
+         **How it works:**
+         1. Upload a **reference image** (your subject with identity/style to preserve)
+         2. Upload a **pose image** (line art or pose skeleton to follow)
+         3. The model will generate your subject in the new pose
+
+         Uses the **refcontrolpose** LoRA for precise pose control.
          """)
+
          with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📸 Input Images")

                  with gr.Row():
+                     with gr.Column():
+                         reference_image = gr.Image(
+                             label="Reference Image (Subject)",
+                             type="pil",
+                             elem_classes=["image-container"],
+                             height=300
+                         )
+                         gr.Markdown("*Upload the image with the subject/style to preserve*")
+
+                     with gr.Column():
+                         pose_image = gr.Image(
+                             label="Pose Control (Line Art)",
+                             type="pil",
+                             elem_classes=["image-container"],
+                             height=300
+                         )
+                         gr.Markdown("*Upload the pose/line art to follow*")
+
+                 # Optional: Add pose extraction tool
+                 with gr.Accordion("🔧 Extract Pose from Image", open=False):
+                     source_for_pose = gr.Image(
+                         label="Source Image for Pose Extraction",
+                         type="pil",
+                         height=200
                      )
+                     extract_pose_btn = gr.Button("Extract Pose", size="sm")
+
+                 prompt = gr.Textbox(
+                     label="Additional Prompt (Optional)",
+                     placeholder="e.g., wearing a red dress, in a garden, professional photography",
+                     info="Add details about the desired output (trigger word 'refcontrolpose' is added automatically)",
+                     lines=2
+                 )
+
+                 with gr.Row():
+                     run_button = gr.Button("🎨 Transfer Pose", variant="primary", scale=2)
+                     clear_button = gr.Button("🗑️ Clear", scale=1)
+
+                 with gr.Accordion("⚙️ Advanced Settings", open=False):

                      seed = gr.Slider(
                          label="Seed",
                          minimum=0,
                          maximum=MAX_SEED,
                          step=1,
+                         value=42,
                      )

+                     randomize_seed = gr.Checkbox(
+                         label="Randomize seed",
+                         value=True
+                     )

                      guidance_scale = gr.Slider(
                          label="Guidance Scale",
+                         minimum=1.0,
+                         maximum=10.0,
+                         step=0.5,
+                         value=3.5,
+                         info="Higher values follow the pose more strictly"
+                     )

+                     strength = gr.Slider(
+                         label="Transformation Strength",
+                         minimum=0.1,
+                         maximum=1.0,
+                         step=0.05,
+                         value=0.85,
+                         info="How much to change from the original"
+                     )
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### 🖼️ Results")
+
+                 result = gr.Image(
+                     label="Generated Result",
+                     elem_classes=["result-container"],
+                     interactive=False,
+                     height=400
+                 )
+
+                 with gr.Accordion("📊 Generation Info", open=False):
+                     used_seed = gr.Number(label="Seed Used", interactive=False)
+                     input_preview = gr.Image(
+                         label="Concatenated Input (Reference | Pose)",
+                         height=200
+                     )
+
+                 with gr.Row():
+                     save_button = gr.Button("💾 Save Result", size="sm")
+                     reuse_button = gr.Button("♻️ Use as Reference", size="sm")

+         # Examples (each row supplies the single prompt input)
+         gr.Markdown("### 💡 Examples")
+         gr.Examples(
+             examples=[
+                 ["A person in business attire, standing confidently"],
+                 ["A dancer in elegant costume, performing a ballet leap"],
+                 ["An athlete in sportswear, doing a martial arts kick"],
+                 ["A model in casual outfit, sitting on a chair"],
+             ],
+             inputs=[prompt],
+             label="Example Prompts"
+         )
+
+     # Event handlers
+     run_button.click(
+         fn=infer_pose_transfer,
+         inputs=[
+             reference_image,
+             pose_image,
+             prompt,
+             seed,
+             randomize_seed,
+             guidance_scale,
+             strength
+         ],
+         outputs=[result, used_seed, input_preview]
+     )
+
+     extract_pose_btn.click(
+         fn=create_pose_from_image,
+         inputs=[source_for_pose],
+         outputs=[pose_image]
      )

      reuse_button.click(
+         fn=lambda img: img,
+         inputs=[result],
+         outputs=[reference_image]
+     )
+
+     clear_button.click(
+         fn=lambda: [None, None, "", None, 42, None],
+         outputs=[reference_image, pose_image, prompt, result, used_seed, input_preview]
      )
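save_button is created in the layout above but never wired to a handler in this commit. One possible wiring, sketched with a hypothetical handler and a gr.File output that would also need to be added to the layout:

    # Hypothetical handler: write the result to a temporary PNG for download.
    import tempfile

    def save_result(img):
        if img is None:
            return None
        path = tempfile.NamedTemporaryFile(suffix=".png", delete=False).name
        img.save(path)
        return path

    # save_button.click(fn=save_result, inputs=[result], outputs=[download_file])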

+ # Launch the app
+ demo.queue()
  demo.launch()