Update app.py
app.py
CHANGED
The previous app.py (346 lines) was rewritten in full; the updated file follows.
import gradio as gr
import numpy as np
import torch
import random
import os
import spaces
from PIL import Image, ImageOps, ImageFilter
from diffusers import FluxPipeline, DiffusionPipeline
from diffusers.loaders import LoraLoaderMixin
import requests
from io import BytesIO

# Constants
MAX_SEED = np.iinfo(np.int32).max
HF_TOKEN = os.getenv("HF_TOKEN")

# Model configuration
KONTEXT_MODEL = "black-forest-labs/FLUX.1-Kontext-dev"
LORA_MODEL = "thedeoxen/refcontrol-flux-kontext-reference-pose-lora"
TRIGGER_WORD = "refcontrolpose"

# Initialize pipeline with authentication
print("Loading models...")

try:
    # Load Flux Kontext pipeline with HF token
    if HF_TOKEN:
        from diffusers import FluxKontextPipeline
        pipe = FluxKontextPipeline.from_pretrained(
            KONTEXT_MODEL,
            torch_dtype=torch.bfloat16,
            use_auth_token=HF_TOKEN
        )

        # Load the RefControl LoRA
        pipe.load_lora_weights(
            LORA_MODEL,
            adapter_name="refcontrol",
            use_auth_token=HF_TOKEN
        )

        # Move to GPU
        pipe = pipe.to("cuda")

        MODEL_STATUS = "✅ Flux Kontext + RefControl LoRA loaded successfully"
        print(MODEL_STATUS)

    else:
        raise ValueError("HF_TOKEN not found in environment variables")

except Exception as e:
    print(f"Error loading models: {e}")
    # Fallback to base model without LoRA
    try:
        pipe = DiffusionPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
            use_auth_token=HF_TOKEN if HF_TOKEN else True
        ).to("cuda")
        MODEL_STATUS = "⚠️ Running in fallback mode (FLUX.1-dev without LoRA)"
    except:
        MODEL_STATUS = "❌ Failed to load models. Please check HF_TOKEN"
        pipe = None

def prepare_images_for_kontext(reference_image, pose_image, target_size=768):
    """
    Prepare reference and pose images for Kontext processing.
    Following the RefControl format: reference (left) | pose (right)
    """
    if reference_image is None or pose_image is None:
        return None

    # Convert to RGB
    reference_image = reference_image.convert("RGB")
    pose_image = pose_image.convert("RGB")

    # Calculate dimensions maintaining aspect ratio
    ref_ratio = reference_image.width / reference_image.height
    pose_ratio = pose_image.width / pose_image.height

    # Set heights to target size
    height = target_size
    ref_width = int(height * ref_ratio)
    pose_width = int(height * pose_ratio)

    # Ensure dimensions are divisible by 8 (FLUX requirement)
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
    height = (height // 8) * 8

    # Resize images
    reference_resized = reference_image.resize((ref_width, height), Image.LANCZOS)
    pose_resized = pose_image.resize((pose_width, height), Image.LANCZOS)

    # Concatenate horizontally: reference | pose
    total_width = ref_width + pose_width
    concatenated = Image.new('RGB', (total_width, height))
    concatenated.paste(reference_resized, (0, 0))
    concatenated.paste(pose_resized, (ref_width, 0))

    return concatenated

def extract_pose_edges(image):
    """
    Extract edge/pose information from an image.
    """
    if image is None:
        return None

    # Convert to grayscale
    gray = image.convert("L")

    # Apply edge detection
    edges = gray.filter(ImageFilter.FIND_EDGES)

    # Enhance contrast
    edges = ImageOps.autocontrast(edges)

    # Invert to get black lines on white
    edges = ImageOps.invert(edges)

    # Smooth the result
    edges = edges.filter(ImageFilter.SMOOTH_MORE)

    # Convert back to RGB
    return edges.convert("RGB")

+
@spaces.GPU(duration=60)
|
| 129 |
+
def generate_pose_transfer(
|
| 130 |
+
reference_image,
|
| 131 |
+
pose_image,
|
| 132 |
+
prompt="",
|
| 133 |
+
negative_prompt="",
|
| 134 |
+
seed=42,
|
| 135 |
+
randomize_seed=False,
|
| 136 |
guidance_scale=3.5,
|
| 137 |
+
num_inference_steps=28,
|
| 138 |
+
lora_scale=1.0,
|
| 139 |
+
enhance_pose=False,
|
| 140 |
progress=gr.Progress(track_tqdm=True)
|
| 141 |
):
|
| 142 |
"""
|
| 143 |
+
Main generation function using RefControl LoRA.
|
| 144 |
"""
|
| 145 |
|
| 146 |
+
if pipe is None:
|
| 147 |
+
return None, 0, "Model not loaded. Please check HF_TOKEN"
|
| 148 |
+
|
| 149 |
if reference_image is None or pose_image is None:
|
| 150 |
+
raise gr.Error("Please upload both reference and pose images")
|
| 151 |
|
| 152 |
+
# Randomize seed if requested
|
| 153 |
if randomize_seed:
|
| 154 |
seed = random.randint(0, MAX_SEED)
|
| 155 |
|
| 156 |
+
# Enhance pose if requested
|
| 157 |
+
if enhance_pose:
|
| 158 |
+
pose_image = extract_pose_edges(pose_image)
|
| 159 |
|
| 160 |
+
# Prepare concatenated input
|
| 161 |
+
concatenated_input = prepare_images_for_kontext(reference_image, pose_image)
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
if concatenated_input is None:
|
| 164 |
+
raise gr.Error("Failed to process images")
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
+
# Construct prompt with trigger word
|
| 167 |
if prompt:
|
| 168 |
+
full_prompt = f"{TRIGGER_WORD}, {prompt}"
|
|
|
|
| 169 |
else:
|
| 170 |
+
full_prompt = f"{TRIGGER_WORD}, transfer the pose from the right image to the subject in the left image while maintaining their identity, clothing, and style"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
+
# Add instruction for the model
|
| 173 |
+
full_prompt += ". The left image shows the reference subject, the right image shows the target pose."
|
| 174 |
|
| 175 |
+
# Set generator for reproducibility
|
| 176 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
+
try:
|
| 179 |
+
# Generate with LoRA
|
| 180 |
+
with torch.autocast("cuda"):
|
| 181 |
+
if hasattr(pipe, 'set_adapters'):
|
| 182 |
+
# Set LoRA adapter strength
|
| 183 |
+
pipe.set_adapters(["refcontrol"], adapter_weights=[lora_scale])
|
| 184 |
+
|
| 185 |
+
# Generate image
|
| 186 |
+
result = pipe(
|
| 187 |
+
image=concatenated_input,
|
| 188 |
+
prompt=full_prompt,
|
| 189 |
+
negative_prompt=negative_prompt,
|
| 190 |
+
guidance_scale=guidance_scale,
|
| 191 |
+
num_inference_steps=num_inference_steps,
|
| 192 |
+
generator=generator,
|
| 193 |
+
width=concatenated_input.width,
|
| 194 |
+
height=concatenated_input.height,
|
| 195 |
+
).images[0]
|
| 196 |
+
|
| 197 |
+
return result, seed, concatenated_input
|
| 198 |
+
|
| 199 |
+
except Exception as e:
|
| 200 |
+
raise gr.Error(f"Generation failed: {str(e)}")
|
| 201 |
|
# CSS styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
.header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
    border-radius: 12px;
    margin-bottom: 20px;
}
.header h1 {
    color: white;
    margin: 0;
}
.status-box {
    padding: 10px;
    border-radius: 8px;
    margin: 10px 0;
    font-weight: bold;
}
.input-image {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    overflow: hidden;
}
.result-image {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    overflow: hidden;
}
"""

# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:

    with gr.Column(elem_id="col-container"):
        # Header with authentication
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <div class="header">
                    <h1>RefControl Flux Kontext - Reference Pose Transfer</h1>
                    <p style="color: white;">Powered by thedeoxen/refcontrol-flux-kontext-reference-pose-lora</p>
                </div>
                """)

                # Model status
                gr.Markdown(f"""
                <div class="status-box" style="background: {'#d4edda' if '✅' in MODEL_STATUS else '#f8d7da'};">
                {MODEL_STATUS}
                </div>
                """)

        # Authentication info
        if not HF_TOKEN:
            gr.Markdown("""
            ### 🔐 Authentication Required
            Please set your Hugging Face token to use this Space:
            1. Go to Settings → Variables and secrets
            2. Add `HF_TOKEN` with your Hugging Face token
            3. Restart the Space
            """)
            gr.LoginButton("Sign in with Hugging Face", size="lg")

        # Main interface
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Input Images")

                # Reference image
                reference_image = gr.Image(
                    label="Reference Image (Subject to transform)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                # Pose image
                pose_image = gr.Image(
                    label="Pose Control (Line art or skeleton)",
                    type="pil",
                    elem_classes=["input-image"],
                    height=300
                )

                # Pose extraction tool
                with gr.Accordion("🔧 Extract Pose from Image", open=False):
                    extract_source = gr.Image(
                        label="Source image for pose extraction",
                        type="pil",
                        height=200
                    )
                    extract_btn = gr.Button("Extract Pose", size="sm")

                # Prompts
                prompt = gr.Textbox(
                    label=f"Prompt (trigger word '{TRIGGER_WORD}' added automatically)",
                    placeholder="e.g., wearing elegant dress, professional photography",
                    lines=2
                )

                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="e.g., blurry, low quality, distorted",
                    lines=2
                )

                # Generate button
                generate_btn = gr.Button(
                    "🎨 Generate Pose Transfer",
                    variant="primary",
                    size="lg"
                )

                # Advanced settings
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    with gr.Row():
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize",
                            value=True
                        )

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=3.5,
                        info="How strictly to follow the pose"
                    )

                    num_inference_steps = gr.Slider(
                        label="Inference Steps",
                        minimum=20,
                        maximum=50,
                        step=1,
                        value=28
                    )

                    lora_scale = gr.Slider(
                        label="LoRA Strength",
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        info="RefControl LoRA influence"
                    )

                    enhance_pose = gr.Checkbox(
                        label="Auto-enhance pose edges",
                        value=False
                    )

            with gr.Column(scale=1):
                gr.Markdown("### 🖼️ Generated Result")

                # Result image
                result_image = gr.Image(
                    label="Generated Image",
                    elem_classes=["result-image"],
                    interactive=False,
                    height=500
                )

                # Info display
                with gr.Row():
                    seed_used = gr.Number(
                        label="Seed Used",
                        interactive=False
                    )

                # Debug view
                with gr.Accordion("🔍 Debug View", open=False):
                    concat_preview = gr.Image(
                        label="Input Concatenation (Reference | Pose)",
                        height=200
                    )

                # Reuse buttons
                with gr.Row():
                    reuse_ref_btn = gr.Button("♻️ Use as Reference", size="sm")
                    reuse_pose_btn = gr.Button("🔄 Extract & Use as Pose", size="sm")
                    clear_btn = gr.Button("🗑️ Clear All", size="sm")

        # Examples
        gr.Markdown("### 💡 Example Prompts")
        gr.Examples(
            examples=[
                ["professional portrait, studio lighting, high quality"],
                ["wearing red dress, outdoor garden setting"],
                ["business attire, corporate headshot"],
                ["casual streetwear, urban background"],
                ["athletic wear, dynamic action shot"],
                ["elegant evening gown, luxury setting"],
            ],
            inputs=[prompt]
        )

        # Instructions
        with gr.Accordion("📖 How to Use", open=False):
            gr.Markdown("""
            1. **Upload Reference Image**: The person/subject you want to transform
            2. **Upload Pose Image**: Line art or skeleton pose to follow
            3. **Optional**: Add descriptive prompt for style/setting
            4. **Click Generate**: Wait for the magic to happen!

            **Tips:**
            - Use clear, high-contrast pose images for best results
            - The model preserves identity from reference while following pose
            - Adjust LoRA strength to balance identity vs pose adherence
            - Higher guidance scale = stricter pose following
            """)

    # Event handlers
    generate_btn.click(
        fn=generate_pose_transfer,
        inputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            guidance_scale,
            num_inference_steps,
            lora_scale,
            enhance_pose
        ],
        outputs=[result_image, seed_used, concat_preview]
    )

    extract_btn.click(
        fn=extract_pose_edges,
        inputs=[extract_source],
        outputs=[pose_image]
    )

    reuse_ref_btn.click(
        fn=lambda x: x,
        inputs=[result_image],
        outputs=[reference_image]
    )

    reuse_pose_btn.click(
        fn=extract_pose_edges,
        inputs=[result_image],
        outputs=[pose_image]
    )

    clear_btn.click(
        fn=lambda: [None, None, "", "", 42, None, None],
        outputs=[
            reference_image,
            pose_image,
            prompt,
            negative_prompt,
            seed_used,
            result_image,
            concat_preview
        ]
    )

# Launch the app
if __name__ == "__main__":
    demo.queue()
    demo.launch()
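
For quick local experimentation with the input format, the sketch below mirrors the side-by-side layout that prepare_images_for_kontext builds (reference on the left, pose map on the right). It is a minimal standalone example, not part of the Space; the helper name build_refcontrol_input and the file names ref.jpg and pose.png are placeholders, and no model is loaded.

from PIL import Image

def build_refcontrol_input(ref_path, pose_path, target_size=768):
    # Load both images and normalize to RGB
    ref = Image.open(ref_path).convert("RGB")
    pose = Image.open(pose_path).convert("RGB")

    # Match heights, keep aspect ratios, snap all dimensions to multiples of 8
    height = (target_size // 8) * 8
    ref_w = (int(height * ref.width / ref.height) // 8) * 8
    pose_w = (int(height * pose.width / pose.height) // 8) * 8

    # Reference on the left, pose control on the right
    canvas = Image.new("RGB", (ref_w + pose_w, height))
    canvas.paste(ref.resize((ref_w, height), Image.LANCZOS), (0, 0))
    canvas.paste(pose.resize((pose_w, height), Image.LANCZOS), (ref_w, 0))
    return canvas

if __name__ == "__main__":
    # Placeholder paths; replace with your own images
    build_refcontrol_input("ref.jpg", "pose.png").save("kontext_input.png")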
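Once the Space is running, it can also be driven programmatically with the gradio_client package. The sketch below is an assumption-heavy example: the Space id your-username/refcontrol-flux-kontext is a placeholder, and the /generate_pose_transfer endpoint name assumes Gradio's default api_name derived from the handler function; check client.view_api() against the actual deployment before relying on it.

from gradio_client import Client, handle_file

# Placeholder Space id; point this at the actual deployment
client = Client("your-username/refcontrol-flux-kontext")

result = client.predict(
    handle_file("ref.jpg"),                      # reference_image
    handle_file("pose.png"),                     # pose_image
    "professional portrait, studio lighting",    # prompt
    "",                                          # negative_prompt
    42,                                          # seed
    True,                                        # randomize_seed
    3.5,                                         # guidance_scale
    28,                                          # num_inference_steps
    1.0,                                         # lora_scale
    False,                                       # enhance_pose
    api_name="/generate_pose_transfer",          # assumed default endpoint name
)

# Typically a tuple: (generated image path, seed used, concatenated-input preview path)
print(result)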