# Hugging Face Space: Nested Attention concept-personalization demo (runs on ZeroGPU).
import spaces
import os
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from huggingface_hub import hf_hub_download, snapshot_download
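
# Repo-local modules: the nested-attention inference wrapper and a face-alignment helper.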
from nested_attention_pipeline import NestedAdapterInference, add_special_token_to_tokenizer
from utils import align_face

# ----------------------
# Configuration (update paths as needed)
# ----------------------
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
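
# Fetch the image encoder and the personalization checkpoint from the Hub
# (cached locally after the first download).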
image_encoder_path = snapshot_download("orpatashnik/NestedAttentionEncoder", allow_patterns=["image_encoder/**"])
image_encoder_path = os.path.join(image_encoder_path, "image_encoder")
personalization_ckpt = hf_hub_download("orpatashnik/NestedAttentionEncoder", "personalization_encoder/model.safetensors")

device = "cuda"

# Special token settings
placeholder_token = "<person>"
initializer_token = "person"

# ----------------------
# Load models
# ----------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
)
add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token)
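
# Wrap the SDXL pipeline with the nested-attention adapter; it loads the image
# encoder and the personalization checkpoint downloaded above.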
ip_model = NestedAdapterInference(
    pipe,
    image_encoder_path,
    personalization_ckpt,
    1024,
    vq_normalize_factor=2.0,
    device=device,
)

# Generation defaults
negative_prompt = "bad anatomy, monochrome, lowres, worst quality, low quality"
num_inference_steps = 30
guidance_scale = 5.0
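# (5.0 matches the default guidance scale of diffusers' SDXL pipeline.)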

# ----------------------
# Inference function with alignment
# ----------------------
@spaces.GPU  # ZeroGPU: request a GPU for the duration of each call (this is what `import spaces` is for)
def generate_images(img1, img2, img3, prompt, w, num_samples, seed):
    # Collect non-empty reference images
    refs = [img for img in (img1, img2, img3) if img is not None]
    if not refs:
        return []

    # Align each face directly on the PIL image
    aligned_refs = [align_face(img) for img in refs]

    # Resize to model resolution
    pil_images = [aligned.resize((512, 512)) for aligned in aligned_refs]
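
    # Look up the tokenizer id of the placeholder token so the adapter knows
    # which prompt position carries the personalized subject.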
    placeholder_token_ids = ip_model.pipe.tokenizer.convert_tokens_to_ids([placeholder_token])

    # Generate personalized samples (-1 from the UI means "use a random seed")
    results = ip_model.generate(
        pil_image=pil_images,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        placeholder_token_ids=placeholder_token_ids,
        seed=seed if seed >= 0 else None,
        guidance_scale=guidance_scale,
        multiple_images=True,
        special_token_weight=w,
    )
    return results
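
# Quick local sanity check, bypassing the UI (uses one of the bundled example
# images; adjust the path if running outside the Space):
#
#   from PIL import Image
#   imgs = generate_images(Image.open("example_images/01.jpg"), None, None,
#                          "a watercolor painting of a <person>", 1.0, 2, 42)
#   print(len(imgs))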

# ----------------------
# Gradio UI
# ----------------------
with gr.Blocks() as demo:
    gr.Markdown("## Nested Attention: Semantic-aware Attention Values for Concept Personalization")
    gr.Markdown(
        "Upload up to 3 reference images. "
        "Faces will be auto-aligned before personalization. Include the placeholder token (e.g., \\<person\\>) in your prompt, "
        "set the token weight, and choose how many outputs you want."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Reference images
            with gr.Row():
                img1 = gr.Image(type="pil", label="Reference Image 1")
                img2 = gr.Image(type="pil", label="Reference Image 2 (optional)")
                img3 = gr.Image(type="pil", label="Reference Image 3 (optional)")

            prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., an abstract pencil drawing of a <person>")
            w_input = gr.Slider(minimum=1.0, maximum=5.0, step=0.5, value=1.0, label="Special Token Weight (w)")
            num_samples_input = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Number of Images to Generate")
            seed_input = gr.Slider(minimum=-1, maximum=100000, step=1, value=-1, label="Random Seed (-1 for random)")
            generate_button = gr.Button("Generate Images")

            # Add examples
            gr.Examples(
                examples=[
                    ["example_images/01.jpg", None, None, "a watercolor painting of a <person>, closeup", 1.0, 4, 1],
                    ["example_images/02.jpg", None, None, "an abstract pencil drawing of a <person>", 1.5, 4, 30],
                    ["example_images/01.jpg", None, None, "a high quality photo of a <person> as a firefighter", 3.0, 4, 10],
                    ["example_images/02.jpg", None, None, "a high quality photo of a <person> smiling in the snow", 2.0, 4, 40],
                    ["example_images/01.jpg", None, None, "a pop figure of a <person>, she stands on a white background", 2.0, 4, 20],
                ],
                inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
                label="Example Prompts",
            )

        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Generated Images", columns=3)

    generate_button.click(
        fn=generate_images,
        inputs=[img1, img2, img3, prompt_input, w_input, num_samples_input, seed_input],
        outputs=output_gallery,
    )

demo.launch()