aiqtech committed (verified)
Commit ee482b5 · Parent(s): 7239641

Update app.py

Switch the Kontext canvas from 768 px to 512 px, round the output size down to multiples of 64 with a 1024 px cap, pass explicit width/height to the Kontext call, and add LoRA-adapter logging and out-of-memory handling.

Files changed (1): app.py (+55 −19)
app.py CHANGED
@@ -80,7 +80,7 @@ def load_pipeline():
 pipe, MODEL_STATUS = load_pipeline()
 print(MODEL_STATUS)
 
-def prepare_images_for_kontext(reference_image, pose_image, target_size=768):
+def prepare_images_for_kontext(reference_image, pose_image, target_size=512):
     """
     Prepare reference and pose images for Kontext processing.
     Following the RefControl format: reference (left) | pose (right)
@@ -175,12 +175,36 @@ def generate_pose_transfer(
     if enhance_pose:
         pose_image = extract_pose_edges(pose_image)
 
-    # Prepare concatenated input
-    concatenated_input = prepare_images_for_kontext(reference_image, pose_image)
+    # Prepare concatenated input with fixed size
+    concatenated_input = prepare_images_for_kontext(reference_image, pose_image, target_size=512)
 
     if concatenated_input is None:
         raise gr.Error("Failed to process images")
 
+    # Ensure dimensions are model-compatible
+    width, height = concatenated_input.size
+    # Round to nearest 64 pixels for stability
+    width = (width // 64) * 64
+    height = (height // 64) * 64
+
+    # Limit maximum size to prevent memory issues
+    max_size = 1024
+    if width > max_size:
+        ratio = max_size / width
+        width = max_size
+        height = int(height * ratio)
+        height = (height // 64) * 64
+
+    if height > max_size:
+        ratio = max_size / height
+        height = max_size
+        width = int(width * ratio)
+        width = (width // 64) * 64
+
+    # Resize if needed
+    if (width, height) != concatenated_input.size:
+        concatenated_input = concatenated_input.resize((width, height), Image.LANCZOS)
+
     # Construct prompt with trigger word
     if prompt:
         full_prompt = f"{TRIGGER_WORD}, {prompt}"
@@ -195,43 +219,55 @@ def generate_pose_transfer(
 
     try:
         # Check if we have LoRA capabilities
-        has_lora = hasattr(pipe, 'set_adapters') and "RefControl" in MODEL_STATUS
+        has_lora = hasattr(pipe, 'set_adapters') and "LoRA" in MODEL_STATUS
 
-        with torch.autocast("cuda"):
-            if has_lora:
-                # Try to set LoRA adapter strength
-                try:
-                    pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
-                except Exception as e:
-                    print(f"Could not set LoRA adapter: {e}")
-
-            # Generate image based on pipeline type
+        # Set LoRA if available
+        if has_lora:
+            try:
+                pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
+                print(f"LoRA adapter set with strength: {lora_scale}")
+            except Exception as e:
+                print(f"LoRA adapter not set: {e}")
+
+        print(f"Generating with size: {width}x{height}")
+        print(f"Prompt: {full_prompt[:100]}...")
+
+        # Generate image
+        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
             if "Kontext" in MODEL_STATUS:
                 # Use Kontext pipeline
                 result = pipe(
                     image=concatenated_input,
                     prompt=full_prompt,
-                    negative_prompt=negative_prompt if negative_prompt else None,
+                    negative_prompt=negative_prompt if negative_prompt else "",
                     guidance_scale=guidance_scale,
                     num_inference_steps=num_inference_steps,
                     generator=generator,
-                    width=concatenated_input.width,
-                    height=concatenated_input.height,
+                    width=width,
+                    height=height,
                 ).images[0]
             else:
-                # Use standard FLUX pipeline (image-to-image)
+                # Use standard FLUX pipeline
                 result = pipe(
                     prompt=full_prompt,
+                    negative_prompt=negative_prompt if negative_prompt else "",
                     image=concatenated_input,
                     guidance_scale=guidance_scale,
                     num_inference_steps=num_inference_steps,
                     generator=generator,
-                    strength=0.85,  # For img2img mode
+                    strength=0.85,
                 ).images[0]
-
+
+        print("Generation successful!")
         return result, seed, concatenated_input
 
+    except RuntimeError as e:
+        if "out of memory" in str(e).lower():
+            raise gr.Error("GPU out of memory. Try reducing image size or inference steps.")
+        else:
+            raise gr.Error(f"Generation failed: {str(e)}")
     except Exception as e:
+        print(f"Error details: {e}")
         raise gr.Error(f"Generation failed: {str(e)}")
 
 # CSS styling
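
The sizing change is the core of this commit: round each dimension of the concatenated canvas down to a multiple of 64, then cap the longer side at 1024 px while preserving the aspect ratio. A minimal standalone sketch of that logic (the helper name fit_dimensions and the example sizes are illustrative, not part of app.py):

    def fit_dimensions(size, multiple=64, max_size=1024):
        """Round (width, height) down to multiples of `multiple` and cap the
        longer side at `max_size`, mirroring the sizing logic added to
        generate_pose_transfer in this commit."""
        width, height = size
        width = (width // multiple) * multiple
        height = (height // multiple) * multiple

        if width > max_size:
            ratio = max_size / width
            width = max_size
            height = (int(height * ratio) // multiple) * multiple

        if height > max_size:
            ratio = max_size / height
            height = max_size
            width = (int(width * ratio) // multiple) * multiple

        return width, height

    # Example: a 1536x768 reference|pose canvas is reduced to 1024x512
    assert fit_dimensions((1536, 768)) == (1024, 512)

In the diff, the resulting size is used twice: to resize concatenated_input with Image.LANCZOS and as the explicit width/height passed to the Kontext pipeline call.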
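
The body of prepare_images_for_kontext is not part of this hunk; its docstring only promises a reference-left / pose-right canvas at target_size. A plausible Pillow sketch of that layout, written purely as an assumption about what the helper does (not the committed implementation):

    from PIL import Image

    def concat_reference_and_pose(reference_image, pose_image, target_size=512):
        # Assumed layout from the docstring: reference (left) | pose (right),
        # both scaled to the same target height.
        def scale_to_height(img):
            ratio = target_size / img.height
            new_width = max(1, int(img.width * ratio))
            return img.convert("RGB").resize((new_width, target_size), Image.LANCZOS)

        ref = scale_to_height(reference_image)
        pose = scale_to_height(pose_image)
        canvas = Image.new("RGB", (ref.width + pose.width, target_size), "white")
        canvas.paste(ref, (0, 0))
        canvas.paste(pose, (ref.width, 0))
        return canvas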