Update app.py
app.py CHANGED
@@ -1,18 +1,15 @@
 import gradio as gr
 import numpy as np
 import torch
-from diffusers.utils import load_image
-from diffusers import (
-    StableDiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    ControlNetModel
-)
+from diffusers.utils import load_image
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 from peft import PeftModel, LoraConfig
 from controlnet_aux import HEDdetector
 from PIL import Image
 import cv2 as cv
 import os
-
+from functools import lru_cache
+from contextlib import contextmanager
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
@@ -23,198 +20,141 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model_id_default = "CompVis/stable-diffusion-v1-4"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-def get_lora_sd_pipeline(
-    ckpt_dir='./lora_logos',
-    base_model_name_or_path=None,
-    dtype=torch.float16,
-    adapter_name="default",
-    controlnet=None
-):
-
-    unet_sub_dir = os.path.join(ckpt_dir, "unet")
-    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
-
-    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
-        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
-        base_model_name_or_path = config.base_model_name_or_path
-
-    if base_model_name_or_path is None:
-        raise ValueError("Please specify the base model name or path")
-
-    pipe = StableDiffusionControlNetPipeline.from_pretrained(
-        base_model_name_or_path,
-        torch_dtype=dtype,
-        controlnet=controlnet,
-    )
-
-    before_params = pipe.unet.parameters()
-    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
-    pipe.unet.set_adapter(adapter_name)
-    after_params = pipe.unet.parameters()
-    print("Parameters changed:", any(torch.any(b != a) for b, a in zip(before_params, after_params)))
-
-    if os.path.exists(text_encoder_sub_dir):
-        pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
+class PipelineManager:
+    def __init__(self):
+        self.pipe = None
+        self.current_model = None
+        self.controlnet_cache = {}
+        self.hed = None
+
+    @lru_cache(maxsize=2)
+    def get_controlnet(self, model_name: str) -> ControlNetModel:
+        if model_name not in self.controlnet_cache:
+            self.controlnet_cache[model_name] = ControlNetModel.from_pretrained(
+                model_name,
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            ).to(device)
+        return self.controlnet_cache[model_name]
 
-
-
-
-
-
+    def get_hed_detector(self):
+        if self.hed is None:
+            self.hed = HEDdetector.from_pretrained('lllyasviel/Annotators')
+        return self.hed
+
+    def initialize_pipeline(self, model_id, controlnet_model):
+        controlnet = self.get_controlnet(controlnet_model)
+        if not self.pipe or model_id != self.current_model:
+            self.pipe = self.create_pipeline(model_id, controlnet)
+            self.current_model = model_id
+        return self.pipe
+
+    def create_pipeline(self, model_id, controlnet):
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            controlnet=controlnet,
+            cache_dir="./models_cache"
+        ).to(device)
+
+        if os.path.exists('./lora_logos'):
+            pipe = self.load_lora_adapters(pipe)
+
+        return pipe
 
-def
-
-
+    def load_lora_adapters(self, pipe):
+        unet_dir = os.path.join('./lora_logos', "unet")
+        text_encoder_dir = os.path.join('./lora_logos', "text_encoder")
+
+        pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_dir, adapter_name="default")
+        if os.path.exists(text_encoder_dir):
+            pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_dir)
+
+        return pipe.to(device)
+
+@contextmanager
+def torch_inference_mode():
+    with torch.inference_mode(), torch.autocast(device.type):
+        yield
+
+def process_embeddings(prompt, negative_prompt, tokenizer, text_encoder):
+    def process_text(text):
+        tokens = tokenizer(text, return_tensors="pt", truncation=False).input_ids
+        chunks = [tokens[:, i:i+77].to(device) for i in range(0, tokens.size(1), 77)]
+        return torch.cat([text_encoder(chunk)[0] for chunk in chunks], dim=1)
 
-
-
+    prompt_emb = process_text(prompt)
+    negative_emb = process_text(negative_prompt)
+    max_len = max(prompt_emb.size(1), negative_emb.size(1))
 
-    return
-
-
-
-    return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
-           torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
-
-def map_edge_detection(image_path: str) -> Image:
-    source_img = load_image(image_path).convert('RGB')
-    edges = cv.Canny(np.array(source_img), 80, 160)
-    edges = np.repeat(edges[:, :, None], 3, axis=2)
-    final_image = Image.fromarray(edges)
-    return final_image
+    return (
+        torch.nn.functional.pad(prompt_emb, (0, 0, 0, max_len - prompt_emb.size(1))),
+        torch.nn.functional.pad(negative_emb, (0, 0, 0, max_len - negative_emb.size(1)))
+    )
 
-def map_scribble(image_path: str) -> Image:
-    global hed
-    if not hed:
-        hed = HEDdetector.from_pretrained('lllyasviel/Annotators')
-
+def process_control_image(image_path: str, processor: str, hed_detector) -> Image:
     image = load_image(image_path).convert('RGB')
-    scribble_image = hed(image)
-    image_np = np.array(scribble_image)
-    image_np = cv.medianBlur(image_np, 3)
-    image = cv.convertScaleAbs(image_np, alpha=1.5, beta=0)
-    final_image = Image.fromarray(image)
-    return final_image
-
-
-
-pipe = get_lora_sd_pipeline(
-    ckpt_dir='./lora_logos',
-    base_model_name_or_path=model_id_default,
-    dtype=torch_dtype,
-    controlnet=controlnet
-).to(device)
-
-
-
-def infer(
-    prompt,
-    negative_prompt,
-    width=512,
-    height=512,
-    num_inference_steps=20,
-    model_id='CompVis/stable-diffusion-v1-4',
-    seed=42,
-    guidance_scale=7.0,
-    lora_scale=0.5,
-    cn_enable=False,
-    cn_strength=0.0,
-    cn_mode='edge_detection',
-    cn_image=None,
-    ip_enable=False,
-    ip_scale=0.5,
-    ip_image=None,
-    progress=gr.Progress(track_tqdm=True)
-):
 
-
+    if processor == 'edge_detection':
+        edges = cv.Canny(np.array(image), 80, 160)
+        return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))
 
-
-
-
-
-
-
-
-
-
-
-            torch_dtype=torch_dtype
-        )
-        controlnet_changed = True
-    else:
-        cn_strength = 0.0  # forcibly disable the ControlNet
-
-    if model_id != pipe._name_or_path:
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            controlnet=controlnet,
-            controlnet_conditioning_scale=cn_strength,
-        ).to(device)
-    elif (model_id == pipe._name_or_path) and controlnet_changed:
-        pipe = StableDiffusionControlNetPipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch_dtype,
-            controlnet=controlnet,
-            controlnet_conditioning_scale=cn_strength,
-        ).to(device)
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
-        pipe.fuse_lora(lora_scale=lora_scale)
-    elif (model_id == pipe._name_or_path) and not controlnet_changed:
-        print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
-        print(f"LoRA scale applied: {lora_scale}")
-        pipe.fuse_lora(lora_scale=lora_scale)
+    if processor == 'scribble':
+        scribble = hed_detector(image)
+        processed = cv.medianBlur(np.array(scribble), 3)
+        return Image.fromarray(cv.convertScaleAbs(processed, alpha=1.5))
+
+pipeline_mgr = PipelineManager()
+controlnet_models = {
+    "edge_detection": "lllyasviel/sd-controlnet-canny",
+    "scribble": "lllyasviel/sd-controlnet-scribble"
+}
 
-
-
-    prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
+def infer(**kwargs):
+    generator = torch.Generator(device).manual_seed(kwargs['seed'])
 
-
-
-
-
-        'num_inference_steps': num_inference_steps,
-        'width': width,
-        'height': height,
-        'generator': generator,
-    }
-
-    if cn_enable:
-        params['controlnet_conditioning_scale'] = cn_strength
-        if cn_mode == 'edge_detection':
-            control_image = map_edge_detection(cn_image)
-        elif cn_mode == 'scribble':
-            control_image = map_scribble(cn_image)
-        params['image'] = control_image
-
-    if ip_enable:
-        pipe.load_ip_adapter(
-            IP_ADAPTER,
-            subfolder="models",
-            weight_name=IP_ADAPTER_WEIGHT_NAME,
+    with torch_inference_mode():
+        pipe = pipeline_mgr.initialize_pipeline(
+            kwargs['model_id'],
+            controlnet_models.get(kwargs['cn_mode'], controlnet_models['edge_detection'])
         )
-
-
-
-
+
+        if kwargs['cn_enable'] and not kwargs['cn_image']:
+            raise gr.Error("ControlNet enabled but no image provided!")
+
+        prompt_emb, negative_emb = process_embeddings(
+            kwargs['prompt'],
+            kwargs['negative_prompt'],
+            pipe.tokenizer,
+            pipe.text_encoder
+        )
+
+        params = {
+            'prompt_embeds': prompt_emb,
+            'negative_prompt_embeds': negative_emb,
+            'guidance_scale': kwargs['guidance_scale'],
+            'num_inference_steps': kwargs['num_inference_steps'],
+            'width': kwargs['width'],
+            'height': kwargs['height'],
+            'generator': generator
+        }
+
+        if kwargs['cn_enable']:
+            params['image'] = process_control_image(
+                kwargs['cn_image'],
+                kwargs['cn_mode'],
+                pipeline_mgr.get_hed_detector()
+            )
+            params['controlnet_conditioning_scale'] = kwargs['cn_strength']
+
+        if kwargs.get('ip_enable', False):
+            pipe.load_ip_adapter(IP_ADAPTER, subfolder="models", weight_name=IP_ADAPTER_WEIGHT_NAME)
+            params['ip_adapter_image'] = load_image(kwargs['ip_image']).convert('RGB')
+            pipe.set_ip_adapter_scale(kwargs.get('ip_scale', 0.5))
+
+        pipe.fuse_lora(lora_scale=kwargs.get('lora_scale', 0.5))
+
+        return pipe(**params).images[0]
 
 css = """
 #col-container {
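For reference, a minimal sketch of how the refactored infer() entry point could be invoked. This is illustrative only, not part of the commit: the prompt text, the sketch path, and the output filename are made up, and it assumes cn_strength is supplied whenever cn_enable is set, since infer() only reads that key inside the ControlNet branch.

# Hypothetical call exercising the kwargs keys that infer() reads above.
result = infer(
    prompt="a minimalist fox logo, flat vector style",
    negative_prompt="blurry, low quality",
    model_id="CompVis/stable-diffusion-v1-4",
    seed=42,
    guidance_scale=7.0,
    num_inference_steps=20,
    width=512,
    height=512,
    lora_scale=0.5,
    cn_enable=True,
    cn_strength=0.8,
    cn_mode="scribble",               # selects lllyasviel/sd-controlnet-scribble
    cn_image="./example_sketch.png",  # hypothetical local control image
    ip_enable=False,
)
result.save("output.png")  # pipe(**params).images[0] is a PIL image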
|