Update app.py
app.py CHANGED

@@ -3,204 +3,344 @@ import numpy as np

Old version:

import spaces
import torch
import random
from PIL import Image
#from kontext_pipeline import FluxKontextPipeline
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Load Kontext model
MAX_SEED = np.iinfo(np.int32).max

…

    """
    …

    Args:
        …

    Returns:
        PIL Image: Concatenated image
    """
    if …
        return None

    …
        return None

    …
        return valid_images[0].convert("RGB")

    …
        # Calculate total width and max height
        total_width = sum(img.width for img in valid_images)
        max_height = max(img.height for img in valid_images)

        # Create new image
        concatenated = Image.new('RGB', (total_width, max_height), (255, 255, 255))

        # Paste images
        x_offset = 0
        for img in valid_images:
            # Center image vertically if heights differ
            y_offset = (max_height - img.height) // 2
            concatenated.paste(img, (x_offset, y_offset))
            x_offset += img.width

    else:  # vertical
        # Calculate max width and total height
        max_width = max(img.width for img in valid_images)
        total_height = sum(img.height for img in valid_images)

        # Create new image
        concatenated = Image.new('RGB', (max_width, total_height), (255, 255, 255))

        # Paste images
        y_offset = 0
        for img in valid_images:
            # Center image horizontally if widths differ
            x_offset = (max_width - img.width) // 2
            concatenated.paste(img, (x_offset, y_offset))
            y_offset += img.height

    return concatenated

@spaces.GPU
def …

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    …
        raise gr.Error("Please upload at least one image.")

    …

    # …
    #     new_height = int(original_height * (new_width / original_width))
    #     new_height = round(new_height / 64) * 64
    # else:
    #     new_height = 1024
    #     new_width = int(original_width * (new_height / original_height))
    #     new_width = round(new_width / 64) * 64

    …

#col-container {
    margin: 0 auto;
    max-width: …
}
"""

with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            …
        """)

        with gr.Row():
            with gr.Column():
                …
                    label="Upload image(s) for editing",
                    show_label=True,
                    elem_id="gallery_input",
                    columns=3,
                    rows=2,
                    object_fit="contain",
                    height="auto",
                    file_types=['image'],
                    type='pil'
                )

        with gr.Row():
            …
            )

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=…
            )

            randomize_seed = gr.Checkbox(
                …

            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=10,
                step=0.…
                value=…
            )

            …
            )

    reuse_button.click(
        fn=…
        inputs=…
        outputs=…
    )

demo.launch()

New version:
import spaces
import torch
import random
from PIL import Image, ImageOps
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

# Load Kontext model with Reference Pose LoRA
MAX_SEED = np.iinfo(np.int32).max

# Initialize the pipeline
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16
).to("cuda")

# Load the Reference Pose LoRA (if available)
# Note: You'll need to add the actual LoRA loading code here
# pipe.load_lora_weights("path/to/refcontrol-pose-lora", adapter_name="refcontrol")
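# A minimal sketch of that loading step with the standard diffusers LoRA API,
# assuming the adapter ships as a single safetensors file; the repo id and
# file name below are hypothetical placeholders, not the real artifact:
#
# pipe.load_lora_weights(
#     "your-username/refcontrol-pose-lora",       # hypothetical repo id
#     weight_name="refcontrolpose.safetensors",   # hypothetical file name
#     adapter_name="refcontrol",
# )
# pipe.set_adapters(["refcontrol"], adapter_weights=[1.0])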

def prepare_pose_reference_pair(reference_image, pose_image):
    """
    Prepare the reference image and pose control map for Kontext processing.

    Args:
        reference_image: PIL Image - The source image with identity/style to preserve
        pose_image: PIL Image - The pose/line art control map

    Returns:
        tuple: (concatenated PIL Image with reference on left and pose on right,
                reference width, pose width)
    """
    if reference_image is None or pose_image is None:
        # Return a 3-tuple so callers that unpack the result don't crash
        return None, None, None

    # Convert images to RGB
    reference_image = reference_image.convert("RGB")
    pose_image = pose_image.convert("RGB")

    # Resize images to have the same height for better concatenation
    target_height = 768  # Standard height for Flux

    # Calculate proportional widths
    ref_ratio = reference_image.width / reference_image.height
    pose_ratio = pose_image.width / pose_image.height

    ref_width = int(target_height * ref_ratio)
    pose_width = int(target_height * pose_ratio)

    # Ensure dimensions are divisible by 8 (required for Flux)
    ref_width = (ref_width // 8) * 8
    pose_width = (pose_width // 8) * 8
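    # Worked example: a 1024x1536 portrait has ratio 1024/1536 = 0.667, so at
    # target_height 768 it becomes int(768 * 0.667) = 512 px wide; 512 is
    # already a multiple of 8, so the rounding step leaves it unchanged.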

    # Resize images
    reference_resized = reference_image.resize((ref_width, target_height), Image.LANCZOS)
    pose_resized = pose_image.resize((pose_width, target_height), Image.LANCZOS)

    # Create concatenated image: reference on left, pose on right
    total_width = ref_width + pose_width
    concatenated = Image.new('RGB', (total_width, target_height), (255, 255, 255))

    # Paste images
    concatenated.paste(reference_resized, (0, 0))
    concatenated.paste(pose_resized, (ref_width, 0))

    return concatenated, ref_width, pose_width

def process_pose_image(pose_image):
    """
    Process the pose image to enhance line art visibility if needed.
    """
    if pose_image is None:
        return None

    pose_image = pose_image.convert("RGB")

    # Optional: Enhance contrast for better pose detection
    # You can add image processing here if the pose needs enhancement
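    # One possibility, using Pillow's built-in autocontrast (left commented out
    # so the default behaviour is unchanged):
    #
    # from PIL import ImageOps
    # pose_image = ImageOps.autocontrast(pose_image, cutoff=1)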

    return pose_image

@spaces.GPU
def infer_pose_transfer(
    reference_image,
    pose_image,
    prompt="",
    seed=42,
    randomize_seed=False,
    guidance_scale=3.5,
    strength=0.85,  # note: accepted from the UI but not passed to the pipeline call below
    progress=gr.Progress(track_tqdm=True)
):
    """
    Transfer pose from control image to reference image using Flux Kontext.
    """
    if reference_image is None or pose_image is None:
        raise gr.Error("Please upload both a reference image and a pose image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Process pose image if needed
    pose_image = process_pose_image(pose_image)

    # Prepare the concatenated input
    concatenated_input, ref_width, pose_width = prepare_pose_reference_pair(
        reference_image,
        pose_image
    )

    if concatenated_input is None:
        raise gr.Error("Failed to process the input images.")

    # Construct the prompt with the trigger word
    base_prompt = "refcontrolpose"

    if prompt:
        # User-provided prompt with trigger word
        full_prompt = f"{base_prompt}, {prompt}"
    else:
        # Default prompt for pose transfer
        full_prompt = f"{base_prompt}, transfer the pose from the right image to the subject in the left image, maintaining the identity, clothing, and style of the original subject while adopting the exact pose and body position shown in the control map"

    # Add instruction for the model to understand the layout
    full_prompt += ". The left side shows the reference with identity to preserve, the right side shows the target pose to follow."
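    # For example, with prompt="wearing a red dress" the text sent to the model is:
    # "refcontrolpose, wearing a red dress. The left side shows the reference with
    # identity to preserve, the right side shows the target pose to follow."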

    # Generate the image
    with torch.autocast("cuda"):
        result = pipe(
            image=concatenated_input,
            prompt=full_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=28,
            width=concatenated_input.size[0],
            height=concatenated_input.size[1],
            generator=torch.Generator("cuda").manual_seed(seed),
        ).images[0]

    # Optional: Crop the result to show only the transformed subject
    # You might want to crop out the concatenated input and show only the result
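    # Sketch of such a crop, assuming the generated canvas keeps the same
    # left/right layout as the input (reference left, posed subject right);
    # left commented out since that assumption is untested:
    #
    # result = result.crop((ref_width, 0, result.width, result.height))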

    return result, seed, concatenated_input

def create_pose_from_image(image):
    """
    Helper function to extract pose/line art from an image.
    This is a placeholder - you might want to integrate with OpenPose or similar.
    """
    if image is None:
        return None

    # Placeholder: In production, you'd use OpenPose or similar
    # For now, we'll just convert to grayscale as a simple edge detection
    from PIL import ImageFilter, ImageOps

    image = image.convert("L")  # Convert to grayscale
    image = image.filter(ImageFilter.FIND_EDGES)  # Simple edge detection
    image = ImageOps.invert(image)  # Invert to get black lines on white
    image = image.convert("RGB")  # Convert back to RGB

    return image
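# A stronger alternative for real photographs, assuming the optional
# controlnet_aux package is installed (detects an actual body skeleton
# instead of running an edge filter):
#
# from controlnet_aux import OpenposeDetector
# openpose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
# pose = openpose(image)  # returns a PIL Image of the detected pose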

# CSS styling
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
.image-container {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    padding: 10px;
    background: #f9f9f9;
}
.result-container {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    padding: 10px;
    background: #f0f8f0;
}
"""

# Create Gradio interface
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # 🎭 FLUX.1 Kontext Reference Pose Transfer

        Transfer any pose to your subject while preserving their identity and style!

        **How it works:**
        1. Upload a **reference image** (your subject with identity/style to preserve)
        2. Upload a **pose image** (line art or pose skeleton to follow)
        3. The model will generate your subject in the new pose

        Uses the **refcontrolpose** LoRA for precise pose control.
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📸 Input Images")

                with gr.Row():
                    with gr.Column():
                        reference_image = gr.Image(
                            label="Reference Image (Subject)",
                            type="pil",
                            elem_classes=["image-container"],
                            height=300
                        )
                        gr.Markdown("*Upload the image with the subject/style to preserve*")

                    with gr.Column():
                        pose_image = gr.Image(
                            label="Pose Control (Line Art)",
                            type="pil",
                            elem_classes=["image-container"],
                            height=300
                        )
                        gr.Markdown("*Upload the pose/line art to follow*")

                # Optional: Add pose extraction tool
                with gr.Accordion("🔧 Extract Pose from Image", open=False):
                    source_for_pose = gr.Image(
                        label="Source Image for Pose Extraction",
                        type="pil",
                        height=200
                    )
                    extract_pose_btn = gr.Button("Extract Pose", size="sm")

                prompt = gr.Textbox(
                    label="Additional Prompt (Optional)",
                    placeholder="e.g., wearing a red dress, in a garden, professional photography",
                    info="Add details about the desired output (trigger word 'refcontrolpose' is added automatically)",
                    lines=2
                )

                with gr.Row():
                    run_button = gr.Button("🎨 Transfer Pose", variant="primary", scale=2)
                    clear_button = gr.Button("🗑️ Clear", scale=1)

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=42,
                    )

                    randomize_seed = gr.Checkbox(
                        label="Randomize seed",
                        value=True
                    )

                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=10.0,
                        step=0.5,
                        value=3.5,
                        info="Higher values follow the pose more strictly"
                    )

                    strength = gr.Slider(
                        label="Transformation Strength",
                        minimum=0.1,
                        maximum=1.0,
                        step=0.05,
                        value=0.85,
                        info="How much to change from the original"
                    )

            with gr.Column(scale=1):
                gr.Markdown("### 🖼️ Results")

                result = gr.Image(
                    label="Generated Result",
                    elem_classes=["result-container"],
                    interactive=False,
                    height=400
                )

                with gr.Accordion("📊 Generation Info", open=False):
                    used_seed = gr.Number(label="Seed Used", interactive=False)
                    input_preview = gr.Image(
                        label="Concatenated Input (Reference | Pose)",
                        height=200
                    )

                with gr.Row():
                    save_button = gr.Button("💾 Save Result", size="sm")
                    reuse_button = gr.Button("♻️ Use as Reference", size="sm")

        # Examples (single prompt strings, matching the single `prompt` input)
        gr.Markdown("### 💡 Examples")
        gr.Examples(
            examples=[
                ["A person in business attire, standing confidently"],
                ["A dancer in elegant costume, performing a ballet leap"],
                ["An athlete in sportswear, doing a martial arts kick"],
                ["A model in casual outfit, sitting on a chair"],
            ],
            inputs=[prompt],
            label="Example Prompts"
        )

        # Event handlers
        run_button.click(
            fn=infer_pose_transfer,
            inputs=[
                reference_image,
                pose_image,
                prompt,
                seed,
                randomize_seed,
                guidance_scale,
                strength
            ],
            outputs=[result, used_seed, input_preview]
        )

        extract_pose_btn.click(
            fn=create_pose_from_image,
            inputs=[source_for_pose],
            outputs=[pose_image]
        )

        reuse_button.click(
            fn=lambda img: img,
            inputs=[result],
            outputs=[reference_image]
        )

        clear_button.click(
            fn=lambda: [None, None, "", None, 42, None],
            outputs=[reference_image, pose_image, prompt, result, used_seed, input_preview]
        )

# Launch the app
demo.queue()
demo.launch()