import gradio as gr
import numpy as np
import torch
import random
import os
import spaces
from PIL import Image, ImageOps, ImageFilter
from diffusers import FluxPipeline, DiffusionPipeline
import requests
from io import BytesIO

MAX_SEED = np.iinfo(np.int32).max
HF_TOKEN = os.getenv("HF_TOKEN")

# Model identifiers and the trigger word used by the RefControl LoRA.
KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
FALLBACK_MODEL = "black-forest-labs/FLUX.1-dev"
LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
TRIGGER_WORD = "refcontrolpose"
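# A minimal requirements.txt sketch for this Space (an assumption, versions unpinned):
# the imports above need gradio, numpy, torch, spaces, Pillow, diffusers, and requests,
# and the LoRA path additionally needs peft (see the note in the UI below).
# transformers, accelerate, and sentencepiece are typically also required by the
# FLUX text encoders.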
print("Loading models...")


def load_pipeline():
    """Load the appropriate pipeline based on availability."""
    global pipe, MODEL_STATUS

    try:
        # Prefer the Kontext pipeline; it needs a recent diffusers build plus PEFT for LoRA.
        try:
            from diffusers import FluxKontextPipeline
            import peft
            print("PEFT library found")
            use_kontext = True
        except ImportError:
            print("FluxKontextPipeline or PEFT not available, using fallback")
            use_kontext = False

        if use_kontext and HF_TOKEN:
            # Both FLUX models are gated on the Hub, so a valid HF_TOKEN is needed here.
            pipe = FluxKontextPipeline.from_pretrained(
                KONTEXT_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN
            )

            try:
                pipe.load_lora_weights(
                    LORA_MODEL,
                    adapter_name="refcontrol",
                    token=HF_TOKEN
                )
                MODEL_STATUS = "✅ Flux Kontext + RefControl LoRA loaded"
            except Exception as e:
                print(f"Could not load LoRA: {e}")
                MODEL_STATUS = "⚠️ Flux Kontext loaded (without LoRA - PEFT required)"

            pipe = pipe.to("cuda")

        else:
            pipe = FluxPipeline.from_pretrained(
                FALLBACK_MODEL,
                torch_dtype=torch.bfloat16,
                token=HF_TOKEN if HF_TOKEN else True
            )
            pipe = pipe.to("cuda")
            MODEL_STATUS = "⚠️ Using FLUX.1-dev (fallback mode)"

    except Exception as e:
        print(f"Error loading models: {e}")
        MODEL_STATUS = f"❌ Error: {str(e)}"
        pipe = None

    return pipe, MODEL_STATUS


pipe, MODEL_STATUS = load_pipeline()
print(MODEL_STATUS)
def prepare_images_for_kontext(reference_image, pose_image, target_size=512):
    """
    Prepare reference and pose images for Kontext processing.
    Following the RefControl format: reference (left) | pose (right)
    """
    if reference_image is None or pose_image is None:
        return None

    reference_image = reference_image.convert("RGB")
    pose_image = pose_image.convert("RGB")

    # Scale both images to the same height while preserving their aspect ratios.
    ref_ratio = reference_image.width / reference_image.height
    pose_ratio = pose_image.width / pose_image.height

    height = target_size
    ref_width = int(height * ref_ratio)
    pose_width = int(height * pose_ratio)

    # Snap all dimensions down to multiples of 8.
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
    height = (height // 8) * 8

    reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS)
    pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS)

    # Concatenate side by side: reference on the left, pose on the right.
    total_width = ref_width + pose_width
    concatenated = Image.new('RGB', (total_width, height))
    concatenated.paste(reference_resized, (0, 0))
    concatenated.paste(pose_resized, (ref_width, 0))

    return concatenated
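# Worked example (sketch): with target_size=512, a 768x1024 reference and a
# 512x768 pose resize to 384x512 and 336x512 respectively (widths snapped down
# to multiples of 8), so prepare_images_for_kontext() returns a 720x512 canvas
# with the reference on the left and the pose on the right.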
def process_pose_for_control(pose_image):
    """
    Process a pose image to maximize contrast and clarity for pose control.
    """
    if pose_image is None:
        return None

    gray = pose_image.convert("L")

    # Extract and strengthen edges.
    edges = gray.filter(ImageFilter.FIND_EDGES)
    edges = edges.filter(ImageFilter.EDGE_ENHANCE_MORE)

    edges = ImageOps.autocontrast(edges, cutoff=2)

    # Binarize to pure black and white.
    threshold = 128
    edges = edges.point(lambda x: 255 if x > threshold else 0, mode='1')

    # Invert so the output is black line art on a white background.
    edges = edges.convert("RGB")
    edges = ImageOps.invert(edges)

    return edges
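# Note: this function is also wired to the "Extract Pose" controls in the UI
# below, so a plain photo can be converted into black-on-white line art before
# being used as the pose input (matching the "black lines on white background"
# tip in the instructions).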
@spaces.GPU(duration=60)
def generate_pose_transfer(
    reference_image,
    pose_image,
    prompt="",
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=7.5,
    num_inference_steps=28,
    lora_scale=1.0,
    enhance_pose=False,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Main generation function using the RefControl approach.
    """
    if pipe is None:
        raise gr.Error("Model not loaded. Please check HF_TOKEN and restart the Space")

    if reference_image is None or pose_image is None:
        raise gr.Error("Please upload both reference and pose images")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Optionally convert the pose input to high-contrast line art first.
    if enhance_pose:
        pose_image = process_pose_for_control(pose_image)

    # Build the side-by-side (reference | pose) conditioning image.
    concatenated_input = prepare_images_for_kontext(reference_image, pose_image, target_size=512)

    if concatenated_input is None:
        raise gr.Error("Failed to process images")

    # Snap the working resolution to multiples of 64 and cap it at 1024 pixels.
    width, height = concatenated_input.size
    width = (width // 64) * 64
    height = (height // 64) * 64

    max_size = 1024
    if width > max_size:
        ratio = max_size / width
        width = max_size
        height = int(height * ratio)
        height = (height // 64) * 64

    if height > max_size:
        ratio = max_size / height
        height = max_size
        width = int(width * ratio)
        width = (width // 64) * 64

    if (width, height) != concatenated_input.size:
        concatenated_input = concatenated_input.resize((width, height), Image.LANCZOS)
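    # Worked example (sketch): the 720x512 canvas from the example above snaps
    # to 704x512 here (720 // 64 * 64 == 704); both sides are already under the
    # 1024 cap, so the only change is a resize from 720x512 to 704x512.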
    base_instruction = f"{TRIGGER_WORD}, A photo composed of two images side by side. Left: reference person. Right: target pose skeleton. Task: Generate the person from the left image in the exact pose shown in the right image"

    if prompt:
        full_prompt = f"{base_instruction}. Additional details: {prompt}"
    else:
        full_prompt = base_instruction

    full_prompt += ". IMPORTANT: Strictly follow the pose/skeleton from the right image while preserving the identity, clothing, and appearance from the left image. The output should show ONLY the transformed person, not the side-by-side layout."

    generator = torch.Generator("cuda").manual_seed(seed)

    try:
        # Only activate the LoRA adapter when it was actually loaded.
        has_lora = hasattr(pipe, 'set_adapters') and "LoRA" in MODEL_STATUS

        if has_lora:
            try:
                # Boost the user-facing LoRA scale for a stronger pose-control effect.
                actual_lora_scale = lora_scale * 1.5
                pipe.set_adapters(["refcontrol"], adapter_weights=[actual_lora_scale])
                print(f"LoRA adapter set with boosted strength: {actual_lora_scale}")
            except Exception as e:
                print(f"LoRA adapter not set: {e}")

        print(f"Generating with size: {width}x{height}")
        print(f"Prompt: {full_prompt[:200]}...")

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            if "Kontext" in MODEL_STATUS:
                result = pipe(
                    image=concatenated_input,
                    prompt=full_prompt,
                    negative_prompt=negative_prompt if negative_prompt else "blurry, distorted, deformed, wrong pose, incorrect posture",
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    width=width,
                    height=height,
                ).images[0]
            else:
                # Fallback path: this call passes image/strength in an img2img style,
                # which assumes the loaded fallback pipeline accepts those arguments.
                result = pipe(
                    prompt=full_prompt,
                    negative_prompt=negative_prompt if negative_prompt else "",
                    image=concatenated_input,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    strength=0.85,
                ).images[0]

        print("Generation successful!")
        return result, seed, concatenated_input

    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            raise gr.Error("GPU out of memory. Try reducing image size or inference steps.")
        else:
            raise gr.Error(f"Generation failed: {str(e)}")
    except Exception as e:
        print(f"Error details: {e}")
        raise gr.Error(f"Generation failed: {str(e)}")
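# Example (sketch) of calling the generator outside the UI, assuming two local
# files "reference.jpg" and "pose.png" exist (hypothetical paths):
#
#     ref = Image.open("reference.jpg")
#     pose = Image.open("pose.png")
#     result, used_seed, concat = generate_pose_transfer(ref, pose, prompt="studio portrait")
#     result.save("output.png")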
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
.header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
    border-radius: 12px;
    margin-bottom: 20px;
}
.header h1 {
    color: white;
    margin: 0;
    font-size: 2em;
}
.status-box {
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
    font-weight: bold;
    text-align: center;
}
.input-image {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    overflow: hidden;
}
.result-image {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    overflow: hidden;
}
.info-box {
    background: #f0f0f0;
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:

    with gr.Column(elem_id="col-container"):

        gr.HTML("""
        <div class="header">
            <h1>🎭 FLUX Pose Transfer System</h1>
            <p style="color: white;">Transfer poses while preserving identity</p>
        </div>
        """)

        # Color the status box green, yellow, or red depending on how the models loaded.
        status_color = "#d4edda" if "✅" in MODEL_STATUS else "#fff3cd" if "⚠️" in MODEL_STATUS else "#f8d7da"
        gr.HTML(f"""
        <div class="status-box" style="background: {status_color};">
            {MODEL_STATUS}
        </div>
        """)

        if not HF_TOKEN:
            gr.Markdown("""
            ### 🔐 Authentication Required

            To use this Space with full features:
            1. Go to **Settings** → **Variables and secrets**
            2. Add `HF_TOKEN` with your Hugging Face token
            3. Restart the Space

            Or click below to sign in:
            """)
            gr.LoginButton("Sign in with Hugging Face", size="lg")

        if "PEFT required" in MODEL_STATUS:
            gr.HTML("""
            <div class="info-box">
                <b>Note:</b> Full LoRA support requires the PEFT library.
                Add <code>peft</code> to your requirements.txt file.
            </div>
            """)
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Input Images")

                reference_image = gr.Image(
                    label="Reference Image (Subject to transform)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                pose_image = gr.Image(
                    label="Pose Control (Line art or skeleton)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                with gr.Accordion("🔧 Extract Pose from Image", open=False):
                    extract_source = gr.Image(
                        label="Source image for pose extraction",
                        type="pil",
                        height=200
                    )
                    extract_btn = gr.Button("Extract Pose", size="sm")

                prompt = gr.Textbox(
                    label=f"Prompt ('{TRIGGER_WORD}' added automatically)",
                    placeholder="e.g., wearing elegant dress, professional photography",
                    lines=2
                )

                negative_prompt = gr.Textbox(
                    label="Negative Prompt (optional)",
                    placeholder="e.g., blurry, low quality, distorted",
                    lines=1,
                    value="blurry, low quality, distorted, deformed"
                )

                generate_btn = gr.Button(
                    "🎨 Generate Pose Transfer",
                    variant="primary",
                    size="lg"
                )

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize",
                            value=True
                        )

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=5.0,
                        maximum=15.0,
                        step=0.5,
                        value=7.5,
                        info="Higher = stricter pose following (7-10 recommended)"
                    )

                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=30
                    )

                    if "LoRA" in MODEL_STATUS:
                        lora_scale = gr.Slider(
                            label="LoRA Strength",
                            minimum=0.5,
                            maximum=2.0,
                            step=0.1,
                            value=1.2,
                            info="RefControl LoRA influence (1.0-1.5 recommended)"
                        )
                    else:
                        lora_scale = gr.Slider(
                            label="LoRA Strength (not available)",
                            minimum=0.0,
                            maximum=2.0,
                            step=0.1,
                            value=1.0,
                            interactive=False
                        )

                    enhance_pose = gr.Checkbox(
                        label="Auto-enhance pose edges",
                        value=False
                    )
            with gr.Column(scale=1):
                gr.Markdown("### 🖼️ Result")

                result_image = gr.Image(
                    label="Generated Image",
                    elem_classes=["result-image"],
                    interactive=False,
                    height=500
                )

                seed_used = gr.Number(
                    label="Seed Used",
                    interactive=False
                )

                with gr.Accordion("🔍 Debug View", open=False):
                    concat_preview = gr.Image(
                        label="Input Concatenation (Reference | Pose)",
                        height=200
                    )

                with gr.Row():
                    reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm")
                    reuse_pose_btn = gr.Button("🔄 Extract Pose", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", size="sm")
        gr.Markdown("### 💡 Example Prompts")
        gr.Examples(
            examples=[
                ["professional portrait, studio lighting"],
                ["wearing red dress, outdoor garden"],
                ["business attire, office setting"],
                ["casual streetwear, urban background"],
                ["athletic wear, gym environment"],
            ],
            inputs=[prompt]
        )

        with gr.Accordion("📋 Instructions", open=False):
            gr.Markdown(f"""
            ## How to Use:

            1. **Upload Reference Image**: The person whose appearance you want to keep
            2. **Upload Pose Image**: Line art or skeleton pose to follow
            3. **Add Prompt** (optional): Describe additional details
            4. **Click Generate**: Create your pose-transferred image

            ## Model Information:
            - **Current Model**: {MODEL_STATUS}
            - **Trigger Word**: `{TRIGGER_WORD}` (added automatically)

            ## Tips:
            - Use clear, high-contrast pose images
            - Black lines on a white background work best for poses
            - Adjust the guidance scale to control pose adherence
            - More steps = better quality but slower generation

            ## Requirements:
            - **HF_TOKEN**: Required for model access
            - **peft**: Required for LoRA support (add to requirements.txt)
            """)
    # Wire up the UI events.
    generate_btn.click(
        fn=generate_pose_transfer,
        inputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            lora_scale,
            enhance_pose
        ],
        outputs=[result_image, seed_used, concat_preview]
    )

    extract_btn.click(
        fn=process_pose_for_control,
        inputs=[extract_source],
        outputs=[pose_image]
    )

    reuse_ref_btn.click(
        fn=lambda x: x,
        inputs=[result_image],
        outputs=[reference_image]
    )

    reuse_pose_btn.click(
        fn=process_pose_for_control,
        inputs=[result_image],
        outputs=[pose_image]
    )

    clear_btn.click(
        fn=lambda: [None, None, "", "blurry, low quality, distorted, deformed", 42, None, None],
        outputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed_used,
            result_image,
            concat_preview
        ]
    )


if __name__ == "__main__":
    demo.queue()
    demo.launch()