|
|
import spaces
import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download

from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
from utils import align_face
|
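# Fetch the pretrained image encoder and the personalization encoder checkpoint from the Hugging Face Hub.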
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = snapshot_download("orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"])
image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
personalization_ckpt = hf_hub_download("orpatashnik/NestedAttentionEncoder", "personalization_encoder/pytorch_model.safetensors")
device = "cuda"
|
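# Placeholder token inserted into prompts to denote the personalized subject; "person" serves as its initializer token.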
placeholder_token = "<person>"
initializer_token = "person"
|
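# Load SDXL in fp16, register the placeholder token with the tokenizer, and wrap the pipeline with the nested-attention personalization adapter.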
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
)
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)

ip_model = NestedAdapterInference(
    pipe,
    image_encoder_path,
    personalization_ckpt,
    1024,
    vq_normalize_factor=2.0,
    device=device,
)
|
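# Fixed sampling settings shared by every generation request.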
negative_prompt = "bad anatomy, monochrome, lowres, worst quality, low quality"
num_inference_steps = 30
guidance_scale = 5.0
|
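# Gradio callback; on Hugging Face Spaces, the @spaces.GPU decorator allocates a GPU for the duration of each call.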
@spaces.GPU
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
    # Keep only the reference images that were actually uploaded.
    refs = [img for img in (img1, img2, img3) if img is not None]
    if not refs:
        return []

    # Align and crop each face, then resize to the resolution expected by the image encoder.
    aligned_refs = [align_face(img) for img in refs]
    pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]

    placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])

    results = ip_model.generate(
        pil_image=pil_images,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        placeholder_token_ids=placeholder_token_ids,
        seed=seed if seed > 0 else None,
        guidance_scale=guidance_scale,
        multiple_images=True,
        special_token_weight=w,
    )
    return results
|
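# Gradio UI: up to three reference images plus prompt and sampling controls on the left, generated images on the right.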
with gr.Blocks() as demo:
    gr.Markdown("## Nested Attention: Semantic-aware Attention Values for Concept Personalization")
    gr.Markdown(
        "Upload up to 3 reference images. "
        "Faces will be auto-aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
        "set the token weight, and choose how many outputs you want."
    )
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                img1 = gr.Image(type="pil", label="Reference Image 1")
                img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
                img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")
            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
            w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
            num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
            seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (-1 for random, up to 100000)")
            generate_button = gr.Button("Generate Images")

            gr.Examples(
                examples=[
                    ["example_images/01.jpg", None, None, "a pop figure of a <person>, she stands on a white background", 2.0, 4, 1],
                    ["example_images/01.jpg", None, None, "a watercolor painting of a <person>, closeup", 1.0, 4, 42],
                    ["example_images/01.jpg", None, None, "a high quality photo of a <person> as a firefighter", 3.0, 4, 10],
                ],
                inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
                label="Example Prompts",
            )

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=3)

    generate_button.click(
        fn=generate_images,
        inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
        outputs=output_gallery,
    )

demo.launch()