|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import random |
|
|
import torch |
|
|
import spaces |
|
|
|
|
|
|
|
|
from diffusers import DiffusionPipeline |
|
|
|
|
|
from pipelines.pipeline_tag_stablediffusion import StableDiffusionTangentialDecomposedPipeline |
|
|
from pipelines.pipeline_tag_stablediffusion3 import StableDiffusion3TangentialDecomposedPipeline |
|
|
from pipelines.pipeline_tag_stablediffusionXL import StableDiffusionXLTangentialDecomposedPipeline |
|
|
|
|
|
|
|
|
# One row per selectable model:
# (UI name, id passed to `from_pretrained`, default width/height,
#  default seed, default "TAG Scale" slider value, pipeline class).
_MODEL_SPECS = (
    ("SD 1.5", "runwayml/stable-diffusion-v1-5", 512, 850728, 1.15,
     StableDiffusionTangentialDecomposedPipeline),
    ("SD 2.1", "stabilityai/stable-diffusion-2-1", 768, 944905, 1.15,
     StableDiffusionTangentialDecomposedPipeline),
    ("SDXL", "stabilityai/stable-diffusion-xl-base-1.0", 1024, 450040818, 1.20,
     StableDiffusionXLTangentialDecomposedPipeline),
    ("SD 3", "stabilityai/stable-diffusion-3-medium-diffusers", 1024, 282386105, 1.08,
     StableDiffusion3TangentialDecomposedPipeline),
)

# Per-model lookup tables, all keyed by the UI name of the model.
MODEL_MAP = {spec[0]: spec[1] for spec in _MODEL_SPECS}
RESOLUTION_MAP = {spec[0]: spec[2] for spec in _MODEL_SPECS}
SEED_MAP = {spec[0]: spec[3] for spec in _MODEL_SPECS}
TAG_SCALE_MAP = {spec[0]: spec[4] for spec in _MODEL_SPECS}
PIPELINE_MAP = {spec[0]: spec[5] for spec in _MODEL_SPECS}

# Run on CUDA in half precision when a GPU is visible; otherwise CPU/float32.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
torch_dtype = torch.float16 if _cuda_available else torch.float32

# The currently loaded pipeline and the model id it was loaded from.
# Both start empty; `load_pipeline` fills them in.
pipe = None
current_model_id = None

# Upper bound of the seed slider and of randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
|
|
|
|
|
|
def load_pipeline(model_name, progress):
    """Load the pipeline for ``model_name`` into the module-level ``pipe``.

    The previously loaded pipeline (if any) is released *before* the new
    weights are loaded, so two models never need to be resident at the same
    time. ``pipe`` and ``current_model_id`` are only published together once
    the new pipeline has been moved to ``device``; if loading fails part-way,
    both stay ``None`` and the next call simply loads again.

    Args:
        model_name: Key of ``MODEL_MAP`` / ``PIPELINE_MAP`` (e.g. ``"SDXL"``).
        progress: Callable invoked as ``progress(fraction, desc=...)`` to
            report loading progress.

    Raises:
        KeyError: If ``model_name`` is not a key of the model tables.
    """
    import gc

    global pipe, current_model_id

    model_id = MODEL_MAP[model_name]
    pipeline_class = PIPELINE_MAP[model_name]
    progress(0, desc=f"Loading model: {model_id} with {pipeline_class.__name__}...")

    # Drop our reference to the old pipeline and invalidate the cached id
    # first: this frees its memory for the new weights and guarantees the two
    # globals can never describe different models if the load below raises.
    pipe = None
    current_model_id = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    load_kwargs = {"torch_dtype": torch_dtype}
    if model_name == "SD 3":
        # SD 3 is loaded without its third text encoder and tokenizer.
        load_kwargs.update(text_encoder_3=None, tokenizer_3=None)

    new_pipe = pipeline_class.from_pretrained(model_id, **load_kwargs)
    new_pipe = new_pipe.to(device)

    # Publish the pipeline and its id together, only after a full success.
    pipe = new_pipe
    current_model_id = model_id
    progress(1)
|
|
|
|
|
def update_model_defaults(model_name):
    """Return UI updates that reset the settings to ``model_name``'s defaults.

    Updates the resolution, seed, randomize-seed checkbox and TAG scale
    according to the selected model.

    Args:
        model_name: Key of ``RESOLUTION_MAP``, ``SEED_MAP`` and
            ``TAG_SCALE_MAP``.

    Returns:
        A 5-tuple of ``gr.update`` objects, in order: width, height, seed,
        randomize-seed (always set to ``False``), TAG scale.
    """
    side = RESOLUTION_MAP[model_name]
    # Width and height both use the model's single default resolution.
    new_values = (
        side,
        side,
        SEED_MAP[model_name],
        False,
        TAG_SCALE_MAP[model_name],
    )
    return tuple(gr.update(value=new_value) for new_value in new_values)
|
|
|
|
|
@spaces.GPU
def infer(
    model_name,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    guidance_start_timestep,
    guidance_end_timestep,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a fixed-scale / custom-scale image pair from the same seed.

    Both images are produced with an empty prompt and ``guidance_scale=0.``;
    they differ only in ``t_guidance_scale``: ``1.0`` for the fixed image and
    the user-supplied ``guidance_scale`` for the custom one.

    Args:
        model_name: Key of ``MODEL_MAP``; the pipeline is (re)loaded when it
            differs from the currently loaded model.
        seed: Seed for both generators (ignored if ``randomize_seed`` is set).
        randomize_seed: When truthy, draw a new seed in ``[0, MAX_SEED]``.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        guidance_scale: Value passed as ``t_guidance_scale`` for the custom
            image.
        num_inference_steps: Number of inference steps for both images.
        guidance_start_timestep: Forwarded to the pipeline as ``sta_tpd``.
        guidance_end_timestep: Forwarded to the pipeline as ``end_tpd``.
        progress: Gradio progress tracker, handed to ``load_pipeline`` when a
            model switch is needed.

    Returns:
        ``([image_fixed_scale, image_custom_scale], seed)`` where ``seed`` is
        the seed actually used.
    """
    if MODEL_MAP[model_name] != current_model_id:
        gr.Info(f"Changing model to {model_name}. Please wait...")
        load_pipeline(model_name, progress)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    def make_generator():
        # Each image gets its own generator seeded identically, so both runs
        # start from the same random state.
        return torch.Generator(device=device).manual_seed(int(seed))

    def render(generator, tag_scale):
        # Everything except the generator and the TAG scale is shared.
        output = pipe(
            prompt="",
            guidance_scale=0.,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            sta_tpd=guidance_start_timestep,
            end_tpd=guidance_end_timestep,
            t_guidance_scale=tag_scale,
        )
        return output.images[0]

    generator_custom = make_generator()
    generator_fixed = make_generator()

    image_custom_scale = render(generator_custom, guidance_scale)
    image_fixed_scale = render(generator_fixed, 1.0)

    return [image_fixed_scale, image_custom_scale], seed
|
|
|
|
|
|
|
|
css = """
#col-container { margin: 0 auto; max-width: 720px; }
"""

# Model whose defaults pre-populate the controls when the page first loads.
_INITIAL_MODEL = "SDXL"

# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation had been lost -- confirm it matches the intended layout.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Tangential Amplifying Guidance Demo")

        model_selector = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_MAP),
            value=_INITIAL_MODEL,
        )

        with gr.Row():
            # Shown for context only: the textbox is non-interactive and is
            # not wired as an input to `infer`.
            prompt = gr.Text(
                label="Prompt (Disabled)",
                show_label=True,
                max_lines=1,
                placeholder="Unconditional generation mode. This input is ignored.",
                container=True,
                interactive=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        result_slider = gr.ImageSlider(
            label="Result Comparison (Fixed Scale vs. Your Scale)",
            show_label=True,
        )

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=SEED_MAP[_INITIAL_MODEL],
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

            # Width and height share the same range and initial value.
            _resolution_kwargs = dict(
                minimum=256,
                maximum=1024,
                step=64,
                value=RESOLUTION_MAP[_INITIAL_MODEL],
            )
            with gr.Row():
                width = gr.Slider(label="Width", **_resolution_kwargs)
                height = gr.Slider(label="Height", **_resolution_kwargs)

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="TAG Scale",
                    minimum=1.0,
                    maximum=1.3,
                    step=0.01,
                    value=TAG_SCALE_MAP[_INITIAL_MODEL],
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=20,
                    maximum=50,
                    step=1,
                    value=50,
                )

            # Both timestep sliders cover the same 0..1000 range.
            _timestep_kwargs = dict(minimum=0, maximum=1000, step=1)
            with gr.Row():
                guidance_start_timestep = gr.Slider(
                    label="Guidance Start Timestep", value=999, **_timestep_kwargs
                )
                guidance_end_timestep = gr.Slider(
                    label="Guidance End Timestep", value=0, **_timestep_kwargs
                )

    # Selecting a model resets the dependent controls to that model's defaults.
    model_selector.change(
        fn=update_model_defaults,
        inputs=[model_selector],
        outputs=[width, height, seed, randomize_seed, guidance_scale],
    )

    # Input order must match the positional parameters of `infer`.
    run_button.click(
        fn=infer,
        inputs=[
            model_selector,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            guidance_start_timestep,
            guidance_end_timestep,
        ],
        outputs=[result_slider, seed],
    )

if __name__ == "__main__":
    demo.launch(debug=True)