import gradio as gr
import numpy as np
import torch
import random
import os
import spaces
from PIL import Image, ImageOps, ImageFilter
from diffusers import FluxImg2ImgPipeline
import requests
from io import BytesIO

MAX_SEED = np.iinfo(np.int32).max
HF_TOKEN = os.getenv("HF_TOKEN")

KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
FALLBACK_MODEL = "black-forest-labs/FLUX.1-dev"
LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
TRIGGER_WORD = "refcontrolpose"
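# Note: the FLUX.1 checkpoints above are gated on the Hugging Face Hub, so HF_TOKEN must
# belong to an account that has accepted their licenses. TRIGGER_WORD is the phrase the
# RefControl LoRA was trained with; generate_pose_transfer() prepends it to every prompt.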

print("Loading models...")


def load_pipeline():
    """Load the appropriate pipeline based on availability"""
    global pipe, MODEL_STATUS

    try:
        try:
            from diffusers import FluxKontextPipeline
            import peft
            print("PEFT library found")
            use_kontext = True
        except ImportError:
            print("FluxKontextPipeline or PEFT not available, using fallback")
            use_kontext = False

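        # Preferred path: load the Kontext editing pipeline and attach the RefControl LoRA.
        # Kontext receives the stitched reference|pose canvas as the image to edit.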
        if use_kontext and HF_TOKEN:
            pipe = FluxKontextPipeline.from_pretrained(
                KONTEXT_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN
            )

            try:
                pipe.load_lora_weights(
                    LORA_MODEL,
                    adapter_name="refcontrol",
                    token=HF_TOKEN
                )
                MODEL_STATUS = "✅ Flux Kontext + RefControl LoRA loaded"
            except Exception as e:
                print(f"Could not load LoRA: {e}")
                MODEL_STATUS = "⚠️ Flux Kontext loaded (without LoRA - PEFT required)"

            pipe = pipe.to("cuda")

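        # Fallback path: run FLUX.1-dev as image-to-image over the stitched canvas so the
        # reference/pose layout still conditions the result (no Kontext editing, no LoRA).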
        else:
            pipe = FluxImg2ImgPipeline.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN if HF_TOKEN else True
            )
            pipe = pipe.to("cuda")
            MODEL_STATUS = "⚠️ Using FLUX.1-dev (fallback mode)"

    except Exception as e:
        print(f"Error loading models: {e}")
        MODEL_STATUS = f"❌ Error: {str(e)}"
        pipe = None

    return pipe, MODEL_STATUS
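

# The status string doubles as a feature flag: the generation code and the UI below branch on
# substrings such as "Kontext", "RefControl", and "PEFT required", so keep those markers intact.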
pipe, MODEL_STATUS = load_pipeline()
print(MODEL_STATUS)


def prepare_images_for_kontext(reference_image, pose_image, target_size=768):
    """
    Prepare reference and pose images for Kontext processing.
    Following the RefControl format: reference (left) | pose (right)
    """
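    # Illustrative example (not executed): two 512x768 inputs with the default target_size=768
    # produce a roughly (512 + 512) x 768 canvas, each width floored to a multiple of 8.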
    if reference_image is None or pose_image is None:
        return None

    reference_image = reference_image.convert("RGB")
    pose_image = pose_image.convert("RGB")

    ref_ratio = reference_image.width / reference_image.height
    pose_ratio = pose_image.width / pose_image.height

    height = target_size
    ref_width = int(height * ref_ratio)
    pose_width = int(height * pose_ratio)
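
    # Floor every dimension to a multiple of 8 so the stitched canvas maps cleanly onto the
    # VAE's downsampled latent grid (the pipeline may still round sizes further internally).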
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
    height = (height // 8) * 8

    reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS)
    pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS)

    total_width = ref_width + pose_width
    concatenated = Image.new('RGB', (total_width, height))
    concatenated.paste(reference_resized, (0, 0))
    concatenated.paste(pose_resized, (ref_width, 0))

    return concatenated


def extract_pose_edges(image):
    """
    Extract edge/pose information from an image.
    """
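    # A simple edge-based approximation built from PIL filters, not a true pose estimator;
    # it works best on clean, high-contrast subjects.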
    if image is None:
        return None

    # Grayscale -> detect edges -> normalize contrast -> invert to dark lines on white -> smooth
    gray = image.convert("L")
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = ImageOps.autocontrast(edges)
    edges = ImageOps.invert(edges)
    edges = edges.filter(ImageFilter.SMOOTH_MORE)

    return edges.convert("RGB")


@spaces.GPU(duration=60)
def generate_pose_transfer(
    reference_image,
    pose_image,
    prompt="",
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=3.5,
    num_inference_steps=28,
    lora_scale=1.0,
    enhance_pose=False,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Main generation function using the RefControl approach.
    """
    if pipe is None:
        raise gr.Error("Model not loaded. Please check HF_TOKEN and restart the Space")

    if reference_image is None or pose_image is None:
        raise gr.Error("Please upload both reference and pose images")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    if enhance_pose:
        pose_image = extract_pose_edges(pose_image)

    concatenated_input = prepare_images_for_kontext(reference_image, pose_image)

    if concatenated_input is None:
        raise gr.Error("Failed to process images")
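
    # Build the prompt: the LoRA trigger word comes first, then either the user's prompt or a
    # default pose-transfer instruction, followed by a hint describing the left/right layout.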
    if prompt:
        full_prompt = f"{TRIGGER_WORD}, {prompt}"
    else:
        full_prompt = f"{TRIGGER_WORD}, transfer the pose from the right image to the subject in the left image while maintaining their identity, clothing, and style"

    full_prompt += ". The left image shows the reference subject, the right image shows the target pose."

    generator = torch.Generator("cuda").manual_seed(seed)
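
    # set_adapters() is only available when the LoRA was attached via PEFT, so detect LoRA
    # support from the pipeline object and the status string instead of assuming it is present.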
    try:
        has_lora = hasattr(pipe, 'set_adapters') and "RefControl" in MODEL_STATUS

        with torch.autocast("cuda"):
            if has_lora:
                try:
                    pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
                except Exception as e:
                    print(f"Could not set LoRA adapter: {e}")
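
            # Kontext consumes the stitched canvas directly and keeps its resolution; otherwise
            # fall back to an img2img pass over the canvas with a fixed strength.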
            if "Kontext" in MODEL_STATUS:
                result = pipe(
                    image=concatenated_input,
                    prompt=full_prompt,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    width=concatenated_input.width,
                    height=concatenated_input.height,
                ).images[0]
            else:
                result = pipe(
                    prompt=full_prompt,
                    image=concatenated_input,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    strength=0.85,
                ).images[0]

        return result, seed, concatenated_input

    except Exception as e:
        raise gr.Error(f"Generation failed: {str(e)}")


css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
.header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
    border-radius: 12px;
    margin-bottom: 20px;
}
.header h1 {
    color: white;
    margin: 0;
    font-size: 2em;
}
.status-box {
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
    font-weight: bold;
    text-align: center;
}
.input-image {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    overflow: hidden;
}
.result-image {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    overflow: hidden;
}
.info-box {
    background: #f0f0f0;
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
}
"""
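
# The selectors above are attached to components below via elem_id="col-container" and
# elem_classes such as "input-image" and "result-image"; gr.Blocks(css=...) injects them.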


with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:

    with gr.Column(elem_id="col-container"):

        gr.HTML("""
        <div class="header">
            <h1>🎭 FLUX Pose Transfer System</h1>
            <p style="color: white;">Transfer poses while preserving identity</p>
        </div>
        """)

        status_color = "#d4edda" if "✅" in MODEL_STATUS else "#fff3cd" if "⚠️" in MODEL_STATUS else "#f8d7da"
        gr.HTML(f"""
        <div class="status-box" style="background: {status_color};">
            {MODEL_STATUS}
        </div>
        """)

        if not HF_TOKEN:
            gr.Markdown("""
            ### 🔐 Authentication Required

            To use this Space with full features:
            1. Go to **Settings** → **Variables and secrets**
            2. Add `HF_TOKEN` with your Hugging Face token
            3. Restart the Space

            Or click below to sign in:
            """)
            gr.LoginButton("Sign in with Hugging Face", size="lg")

        if "PEFT required" in MODEL_STATUS:
            gr.HTML("""
            <div class="info-box">
                <b>Note:</b> For full LoRA support, the PEFT library is required.
                Add <code>peft</code> to your requirements.txt file.
            </div>
            """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Input Images")

                reference_image = gr.Image(
                    label="Reference Image (Subject to transform)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                pose_image = gr.Image(
                    label="Pose Control (Line art or skeleton)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                with gr.Accordion("🔧 Extract Pose from Image", open=False):
                    extract_source = gr.Image(
                        label="Source image for pose extraction",
                        type="pil",
                        height=200
                    )
                    extract_btn = gr.Button("Extract Pose", size="sm")

                prompt = gr.Textbox(
                    label=f"Prompt ('{TRIGGER_WORD}' added automatically)",
                    placeholder="e.g., wearing elegant dress, professional photography",
                    lines=2
                )

                negative_prompt = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="e.g., blurry, low quality, distorted",
                    lines=1,
                    value="blurry, low quality, distorted, deformed"
                )

                generate_btn = gr.Button(
                    "🎨 Generate Pose Transfer",
                    variant="primary",
                    size="lg"
                )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize",
                            value=True
                        )

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=3.5,
                        info="How strictly to follow the pose"
                    )

                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=28
                    )

                    if "RefControl" in MODEL_STATUS:
                        lora_scale = gr.Slider(
                            label="LoRA Strength",
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.0,
                            info="RefControl LoRA influence"
                        )
                    else:
                        lora_scale = gr.Slider(
                            label="LoRA Strength (not available)",
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.0,
                            interactive=False
                        )

                    enhance_pose = gr.Checkbox(
                        label="Auto-enhance pose edges",
                        value=False
                    )

            with gr.Column(scale=1):
                gr.Markdown("### 🖼️ Result")

                result_image = gr.Image(
                    label="Generated Image",
                    elem_classes=["result-image"],
                    interactive=False,
                    height=500
                )

                seed_used = gr.Number(
                    label="Seed Used",
                    interactive=False
                )

                with gr.Accordion("🔍 Debug View", open=False):
                    concat_preview = gr.Image(
                        label="Input Concatenation (Reference | Pose)",
                        height=200
                    )

                with gr.Row():
                    reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm")
                    reuse_pose_btn = gr.Button("🔄 Extract Pose", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", size="sm")

                gr.Markdown("### 💡 Example Prompts")
                gr.Examples(
                    examples=[
                        ["professional portrait, studio lighting"],
                        ["wearing red dress, outdoor garden"],
                        ["business attire, office setting"],
                        ["casual streetwear, urban background"],
                        ["athletic wear, gym environment"],
                    ],
                    inputs=[prompt]
                )

        with gr.Accordion("📖 Instructions", open=False):
            gr.Markdown(f"""
            ## How to Use:

            1. **Upload Reference Image**: The person whose appearance you want to keep
            2. **Upload Pose Image**: Line art or skeleton pose to follow
            3. **Add Prompt** (optional): Describe additional details
            4. **Click Generate**: Create your pose-transferred image

            ## Model Information:
            - **Current Model**: {MODEL_STATUS}
            - **Trigger Word**: `{TRIGGER_WORD}` (added automatically)

            ## Tips:
            - Use clear, high-contrast pose images
            - Black lines on white background work best for poses
            - Adjust guidance scale for pose adherence strength
            - Higher steps = better quality but slower

            ## Requirements:
            - **HF_TOKEN**: Required for model access
            - **peft**: Required for LoRA support (add to requirements.txt)
            """)

    generate_btn.click(
        fn=generate_pose_transfer,
        inputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            lora_scale,
            enhance_pose
        ],
        outputs=[result_image, seed_used, concat_preview]
    )
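
    # The outputs list must stay in the same order as the tuple returned by
    # generate_pose_transfer(): (result image, seed actually used, stitched debug preview).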

    extract_btn.click(
        fn=extract_pose_edges,
        inputs=[extract_source],
        outputs=[pose_image]
    )

    reuse_ref_btn.click(
        fn=lambda x: x,
        inputs=[result_image],
        outputs=[reference_image]
    )

    reuse_pose_btn.click(
        fn=extract_pose_edges,
        inputs=[result_image],
        outputs=[pose_image]
    )

    clear_btn.click(
        fn=lambda: [None, None, "", "blurry, low quality, distorted, deformed", 42, None, None],
        outputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed_used,
            result_image,
            concat_preview
        ]
    )


if __name__ == "__main__":
    demo.queue()
    demo.launch()