aiqtech committed
Commit 1226094 · verified · Parent: cd76271

Update app.py

Files changed (1): app.py (+184 -111)
app.py CHANGED
@@ -6,7 +6,6 @@ import os
 import spaces
 from PIL import Image, ImageOps, ImageFilter
 from diffusers import FluxPipeline, DiffusionPipeline
-from diffusers.loaders import LoraLoaderMixin
 import requests
 from io import BytesIO
 
@@ -16,51 +15,70 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 
 # Model configuration
 KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
+FALLBACK_MODEL = "black-forest-labs/FLUX.1-dev"
 LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
 TRIGGER_WORD = "refcontrolpose"
 
-# Initialize pipeline with authentication
+# Initialize pipeline
 print("Loading models...")
 
-try:
-    # Load Flux Kontext pipeline with HF token
-    if HF_TOKEN:
-        from diffusers import FluxKontextPipeline
-        pipe = FluxKontextPipeline.from_pretrained(
-            KONTEXT_MODEL,
-            torch_dtype=torch.bfloat16,
-            use_auth_token=HF_TOKEN
-        )
-
-        # Load the RefControl LoRA
-        pipe.load_lora_weights(
-            LORA_MODEL,
-            adapter_name="refcontrol",
-            use_auth_token=HF_TOKEN
-        )
-
-        # Move to GPU
-        pipe = pipe.to("cuda")
-
-        MODEL_STATUS = "βœ… Flux Kontext + RefControl LoRA loaded successfully"
-        print(MODEL_STATUS)
-
-    else:
-        raise ValueError("HF_TOKEN not found in environment variables")
-
-except Exception as e:
-    print(f"Error loading models: {e}")
-    # Fallback to base model without LoRA
+def load_pipeline():
+    """Load the appropriate pipeline based on availability"""
+    global pipe, MODEL_STATUS
+
     try:
-        pipe = DiffusionPipeline.from_pretrained(
-            "black-forest-labs/FLUX.1-dev",
-            torch_dtype=torch.bfloat16,
-            use_auth_token=HF_TOKEN if HF_TOKEN else True
-        ).to("cuda")
-        MODEL_STATUS = "⚠️ Running in fallback mode (FLUX.1-dev without LoRA)"
-    except:
-        MODEL_STATUS = "❌ Failed to load models. Please check HF_TOKEN"
+        # First, try to import necessary libraries
+        try:
+            from diffusers import FluxKontextPipeline
+            import peft
+            print("PEFT library found")
+            use_kontext = True
+        except ImportError:
+            print("FluxKontextPipeline or PEFT not available, using fallback")
+            use_kontext = False
+
+        if use_kontext and HF_TOKEN:
+            # Try to load Kontext model
+            pipe = FluxKontextPipeline.from_pretrained(
+                KONTEXT_MODEL,
+                torch_dtype=torch.bfloat16,
+                token=HF_TOKEN
+            )
+
+            # Try to load LoRA if PEFT is available
+            try:
+                pipe.load_lora_weights(
+                    LORA_MODEL,
+                    adapter_name="refcontrol",
+                    token=HF_TOKEN
+                )
+                MODEL_STATUS = "βœ… Flux Kontext + RefControl LoRA loaded"
+            except Exception as e:
+                print(f"Could not load LoRA: {e}")
+                MODEL_STATUS = "⚠️ Flux Kontext loaded (without LoRA - PEFT required)"
+
+            pipe = pipe.to("cuda")
+
+        else:
+            # Fallback to standard FLUX
+            pipe = FluxPipeline.from_pretrained(
+                FALLBACK_MODEL,
+                torch_dtype=torch.bfloat16,
+                token=HF_TOKEN if HF_TOKEN else True
+            )
+            pipe = pipe.to("cuda")
+            MODEL_STATUS = "⚠️ Using FLUX.1-dev (fallback mode)"
+
+    except Exception as e:
+        print(f"Error loading models: {e}")
+        MODEL_STATUS = f"❌ Error: {str(e)}"
         pipe = None
+
+    return pipe, MODEL_STATUS
+
+# Load the pipeline
+pipe, MODEL_STATUS = load_pipeline()
+print(MODEL_STATUS)
 
 def prepare_images_for_kontext(reference_image, pose_image, target_size=768):
     """
@@ -140,11 +158,11 @@ def generate_pose_transfer(
     progress=gr.Progress(track_tqdm=True)
 ):
     """
-    Main generation function using RefControl LoRA.
+    Main generation function using RefControl approach.
     """
 
     if pipe is None:
-        return None, 0, "Model not loaded. Please check HF_TOKEN"
+        return None, 0, "Model not loaded. Please check HF_TOKEN and restart the Space"
 
     if reference_image is None or pose_image is None:
         raise gr.Error("Please upload both reference and pose images")
@@ -176,23 +194,40 @@ def generate_pose_transfer(
     generator = torch.Generator("cuda").manual_seed(seed)
 
     try:
-        # Generate with LoRA
+        # Check if we have LoRA capabilities
+        has_lora = hasattr(pipe, 'set_adapters') and "RefControl" in MODEL_STATUS
+
         with torch.autocast("cuda"):
-            if hasattr(pipe, 'set_adapters'):
-                # Set LoRA adapter strength
-                pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
+            if has_lora:
+                # Try to set LoRA adapter strength
+                try:
+                    pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
+                except Exception as e:
+                    print(f"Could not set LoRA adapter: {e}")
 
-            # Generate image
-            result = pipe(
-                image=concatenated_input,
-                prompt=full_prompt,
-                negative_prompt=negative_prompt,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                width=concatenated_input.width,
-                height=concatenated_input.height,
-            ).images[0]
+            # Generate image based on pipeline type
+            if "Kontext" in MODEL_STATUS:
+                # Use Kontext pipeline
+                result = pipe(
+                    image=concatenated_input,
+                    prompt=full_prompt,
+                    negative_prompt=negative_prompt if negative_prompt else None,
+                    guidance_scale=guidance_scale,
+                    num_inference_steps=num_inference_steps,
+                    generator=generator,
+                    width=concatenated_input.width,
+                    height=concatenated_input.height,
+                ).images[0]
+            else:
+                # Use standard FLUX pipeline (image-to-image)
+                result = pipe(
+                    prompt=full_prompt,
+                    image=concatenated_input,
+                    guidance_scale=guidance_scale,
+                    num_inference_steps=num_inference_steps,
+                    generator=generator,
+                    strength=0.85,  # For img2img mode
+                ).images[0]
 
         return result, seed, concatenated_input
 
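The `lora_scale` slider value flows straight into `set_adapters` above. A small sketch of how one might sweep that weight offline to compare identity preservation against pose adherence; it assumes a `pipe` that already carries the `"refcontrol"` adapter, and the scale values are arbitrary examples.

```python
import torch

def sweep_lora_scale(pipe, image, prompt, scales=(0.5, 1.0, 1.5), seed=42):
    """Render the same prompt at several adapter weights (assumes the
    "refcontrol" adapter was loaded as in the hunk above)."""
    results = []
    for scale in scales:
        pipe.set_adapters(["refcontrol"], adapter_weights=[scale])
        # Reuse the seed so only the adapter weight varies between runs
        generator = torch.Generator("cuda").manual_seed(seed)
        results.append(pipe(image=image, prompt=prompt, generator=generator).images[0])
    return results
```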
@@ -215,12 +250,14 @@ css = """
 .header h1 {
     color: white;
     margin: 0;
+    font-size: 2em;
 }
 .status-box {
     padding: 10px;
     border-radius: 8px;
     margin: 10px 0;
     font-weight: bold;
+    text-align: center;
 }
 .input-image {
     border: 2px solid #e0e0e0;
@@ -232,40 +269,57 @@ css = """
     border-radius: 8px;
     overflow: hidden;
 }
+.info-box {
+    background: #f0f0f0;
+    padding: 10px;
+    border-radius: 8px;
+    margin: 10px 0;
+}
 """
 
 # Create Gradio interface
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
 
     with gr.Column(elem_id="col-container"):
-        # Header with authentication
-        with gr.Row():
-            with gr.Column():
-                gr.HTML("""
-                    <div class="header">
-                        <h1>🎭 RefControl Flux Kontext - Reference Pose Transfer</h1>
-                        <p style="color: white;">Powered by thedeoxen/refcontrol-flux-kontext-reference-pose-lora</p>
-                    </div>
-                """)
-
-        # Model status
-        gr.Markdown(f"""
-        <div class="status-box" style="background: {'#d4edda' if 'βœ…' in MODEL_STATUS else '#f8d7da'};">
-            {MODEL_STATUS}
-        </div>
-        """)
+        # Header
+        gr.HTML("""
+            <div class="header">
+                <h1>🎭 FLUX Pose Transfer System</h1>
+                <p style="color: white;">Transfer poses while preserving identity</p>
+            </div>
+        """)
+
+        # Model status
+        status_color = "#d4edda" if "βœ…" in MODEL_STATUS else "#fff3cd" if "⚠️" in MODEL_STATUS else "#f8d7da"
+        gr.HTML(f"""
+        <div class="status-box" style="background: {status_color};">
+            {MODEL_STATUS}
+        </div>
+        """)
 
-        # Authentication info
+        # Authentication check
         if not HF_TOKEN:
             gr.Markdown("""
             ### πŸ” Authentication Required
-            Please set your Hugging Face token to use this Space:
-            1. Go to Settings β†’ Variables and secrets
+
+            To use this Space with full features:
+            1. Go to **Settings** β†’ **Variables and secrets**
             2. Add `HF_TOKEN` with your Hugging Face token
             3. Restart the Space
+
+            Or click below to sign in:
             """)
             gr.LoginButton("Sign in with Hugging Face", size="lg")
 
+        # Info box for PEFT requirement
+        if "PEFT required" in MODEL_STATUS:
+            gr.HTML("""
+            <div class="info-box">
+                <b>Note:</b> For full LoRA support, PEFT library is required.
+                Add <code>peft</code> to your requirements.txt file.
+            </div>
+            """)
+
         # Main interface
         with gr.Row():
             with gr.Column(scale=1):
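
An aside on the PEFT gate above: the commit infers availability from the status string set at load time. A direct standard-library probe is another option; this helper is illustrative and not part of the commit.

```python
import importlib.util

# True when the peft package is importable, i.e. present in requirements.txt
HAS_PEFT = importlib.util.find_spec("peft") is not None
print("peft available:", HAS_PEFT)
```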
@@ -298,15 +352,16 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
 
                 # Prompts
                 prompt = gr.Textbox(
-                    label=f"Prompt (trigger word '{TRIGGER_WORD}' added automatically)",
+                    label=f"Prompt ('{TRIGGER_WORD}' added automatically)",
                     placeholder="e.g., wearing elegant dress, professional photography",
                     lines=2
                 )
 
                 negative_prompt = gr.Textbox(
-                    label="Negative Prompt",
+                    label="Negative Prompt (optional)",
                     placeholder="e.g., blurry, low quality, distorted",
-                    lines=2
+                    lines=1,
+                    value="blurry, low quality, distorted, deformed"
                 )
 
                 # Generate button
@@ -348,14 +403,24 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     value=28
                 )
 
-                lora_scale = gr.Slider(
-                    label="LoRA Strength",
-                    minimum=0.0,
-                    maximum=2.0,
-                    step=0.1,
-                    value=1.0,
-                    info="RefControl LoRA influence"
-                )
+                if "LoRA" in MODEL_STATUS:
+                    lora_scale = gr.Slider(
+                        label="LoRA Strength",
+                        minimum=0.0,
+                        maximum=2.0,
+                        step=0.1,
+                        value=1.0,
+                        info="RefControl LoRA influence"
+                    )
+                else:
+                    lora_scale = gr.Slider(
+                        label="LoRA Strength (not available)",
+                        minimum=0.0,
+                        maximum=2.0,
+                        step=0.1,
+                        value=1.0,
+                        interactive=False
+                    )
 
                 enhance_pose = gr.Checkbox(
                     label="Auto-enhance pose edges",
@@ -363,7 +428,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
 
             with gr.Column(scale=1):
-                gr.Markdown("### πŸ–ΌοΈ Generated Result")
+                gr.Markdown("### πŸ–ΌοΈ Result")
 
                 # Result image
                 result_image = gr.Image(
@@ -373,12 +438,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     height=500
                 )
 
-                # Info display
-                with gr.Row():
-                    seed_used = gr.Number(
-                        label="Seed Used",
-                        interactive=False
-                    )
+                # Seed display
+                seed_used = gr.Number(
+                    label="Seed Used",
+                    interactive=False
+                )
 
                 # Debug view
                 with gr.Accordion("πŸ” Debug View", open=False):
@@ -387,39 +451,48 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                         height=200
                     )
 
-                # Reuse buttons
+                # Action buttons
                 with gr.Row():
                     reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm")
-                    reuse_pose_btn = gr.Button("πŸ“ Extract & Use as Pose", size="sm")
+                    reuse_pose_btn = gr.Button("πŸ“ Extract Pose", size="sm")
                     clear_btn = gr.Button("πŸ—‘οΈ Clear All", size="sm")
 
         # Examples
         gr.Markdown("### πŸ’‘ Example Prompts")
         gr.Examples(
             examples=[
-                ["professional portrait, studio lighting, high quality"],
-                ["wearing red dress, outdoor garden setting"],
-                ["business attire, corporate headshot"],
+                ["professional portrait, studio lighting"],
+                ["wearing red dress, outdoor garden"],
+                ["business attire, office setting"],
                 ["casual streetwear, urban background"],
-                ["athletic wear, dynamic action shot"],
-                ["elegant evening gown, luxury setting"],
+                ["athletic wear, gym environment"],
             ],
             inputs=[prompt]
         )
 
         # Instructions
-        with gr.Accordion("πŸ“– How to Use", open=False):
-            gr.Markdown("""
-            1. **Upload Reference Image**: The person/subject you want to transform
+        with gr.Accordion("πŸ“– Instructions", open=False):
+            gr.Markdown(f"""
+            ## How to Use:
+
+            1. **Upload Reference Image**: The person whose appearance you want to keep
             2. **Upload Pose Image**: Line art or skeleton pose to follow
-            3. **Optional**: Add descriptive prompt for style/setting
-            4. **Click Generate**: Wait for the magic to happen!
+            3. **Add Prompt** (optional): Describe additional details
+            4. **Click Generate**: Create your pose-transferred image
+
+            ## Model Information:
+            - **Current Model**: {MODEL_STATUS}
+            - **Trigger Word**: `{TRIGGER_WORD}` (added automatically)
 
-            **Tips:**
-            - Use clear, high-contrast pose images for best results
-            - The model preserves identity from reference while following pose
-            - Adjust LoRA strength to balance identity vs pose adherence
-            - Higher guidance scale = stricter pose following
+            ## Tips:
+            - Use clear, high-contrast pose images
+            - Black lines on white background work best for poses
+            - Adjust guidance scale for pose adherence strength
+            - Higher steps = better quality but slower
+
+            ## Requirements:
+            - **HF_TOKEN**: Required for model access
+            - **peft**: Required for LoRA support (add to requirements.txt)
             """)
 
         # Event handlers
@@ -459,7 +532,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
         )
 
         clear_btn.click(
-            fn=lambda: [None, None, "", "", 42, None, None],
+            fn=lambda: [None, None, "", "blurry, low quality, distorted, deformed", 42, None, None],
             outputs=[
                 reference_image,
                 pose_image,
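
For completeness, a rough local smoke test for the reworked loader, run outside Gradio. It assumes a CUDA GPU and `HF_TOKEN` in the environment; the low step count and the blank canvas (Kontext expects an input image) are only there to keep the check cheap, not to produce useful output.

```python
import torch
from PIL import Image

pipe, status = load_pipeline()
print(status)

if pipe is not None:
    kwargs = dict(
        prompt="refcontrolpose test render",
        num_inference_steps=4,   # fast sanity check, not a quality setting
        guidance_scale=2.5,
        generator=torch.Generator("cuda").manual_seed(42),
    )
    if "Kontext" in status:
        kwargs["image"] = Image.new("RGB", (512, 512), "white")
    pipe(**kwargs).images[0].save("smoke_test.png")
```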