|
|
import os |
|
|
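# One-time environment setup for the hosted Space: expose OpenCV's pkg-config file,
# pin NumPy below 2.x for binary compatibility, and remove triton, which this app does not use.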
os.system("cp opencv.pc /usr/local/lib/pkgconfig/") |
|
|
os.system("pip install 'numpy<2'") |
|
|
os.system("pip uninstall triton -y") |
|
|
import spaces |
|
|
import io |
|
|
import base64 |
|
|
import sys |
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image, ImageOps |
|
|
import gradio as gr |
|
|
import skimage |
|
|
import skimage.measure |
|
|
import yaml |
|
|
import json |
|
|
from enum import Enum |
|
|
from utils import * |
|
|
from collections import Counter |
|
|
import argparse |
|
|
from stablepy import Model_Diffusers, scheduler_names, ALL_PROMPT_WEIGHT_OPTIONS, SCHEDULE_TYPE_OPTIONS |
|
|
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars |
|
|
from datetime import datetime |
|
|
|
|
|
parser = argparse.ArgumentParser(description="stablediffusion-infinity") |
|
|
parser.add_argument("--port", type=int, help="listen port", dest="server_port") |
|
|
parser.add_argument("--host", type=str, help="host", dest="server_name") |
|
|
parser.add_argument("--share", action="store_true", help="share this app?") |
|
|
parser.add_argument("--debug", action="store_true", help="debug mode") |
|
|
parser.add_argument("--fp32", action="store_true", help="using full precision") |
|
|
parser.add_argument("--lowvram", action="store_true", help="using lowvram mode") |
|
|
parser.add_argument("--encrypt", action="store_true", help="using https?") |
|
|
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile") |
|
|
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile") |
|
|
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password") |
|
|
parser.add_argument( |
|
|
"--auth", nargs=2, metavar=("username", "password"), help="use username password" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--remote_model", |
|
|
type=str, |
|
|
help="use a model (e.g. dreambooth fined) from huggingface hub", |
|
|
default="", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--local_model", type=str, help="use a model stored on your PC", default="" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--stablepy_model", |
|
|
type=str, |
|
|
help="Model source: can be a Hugging Face Diffusers repo or a local .safetensors file path", |
|
|
default="SG161222/RealVisXL_V5.0_Lightning" |
|
|
) |
|
|
|
|
|
# Run from the script's own directory; __file__ may be undefined (e.g. in a REPL).
try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except (NameError, OSError):
    pass
|
|
|
|
|
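# Captioning backend for "interrogate" mode; stubbed out here with the DummyInterrogator from utils.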
Interrogator = DummyInterrogator |
|
|
|
|
|
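# Under ZeroGPU (SPACES_ZERO_GPU is set) the model must be built on CPU first and moved to CUDA after loading.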
START_DEVICE_STABLEPY = "cpu" if os.getenv("SPACES_ZERO_GPU") else None |
|
|
DEBUG_MODE = False |
|
|
AUTO_SETUP = True |
|
|
|
|
|
|
|
|
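# UI configuration: config.yaml is re-serialized to JSON and injected into the keyboard-shortcut script further below.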
with open("config.yaml", "r") as yaml_in: |
|
|
yaml_object = yaml.safe_load(yaml_in) |
|
|
config_json = json.dumps(yaml_object) |
|
|
|
|
|
|
|
|
def parse_color(color): |
|
|
""" |
|
|
Convert color to Pillow-friendly (R, G, B, A) tuple in 0–255 range. |
|
|
Supports: |
|
|
- tuple/list of floats or ints |
|
|
- 'rgba(r, g, b, a)' string |
|
|
- 'rgb(r, g, b)' string |
|
|
- hex colors: '#RRGGBB' or '#RRGGBBAA' |
|
|
""" |
|
|
if isinstance(color, (tuple, list)): |
|
|
parts = [float(c) for c in color] |
|
|
|
|
|
elif isinstance(color, str): |
|
|
c = color.strip().lower() |
|
|
|
|
|
|
|
|
if c.startswith("#"): |
|
|
c = c.lstrip("#") |
|
|
if len(c) == 6: |
|
|
r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16) |
|
|
return (r, g, b, 255) |
|
|
elif len(c) == 8: |
|
|
r, g, b, a = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16), int(c[6:8], 16) |
|
|
return (r, g, b, a) |
|
|
else: |
|
|
raise ValueError(f"Invalid hex color: {color}") |
|
|
|
|
|
|
|
|
c = c.replace("rgba", "").replace("rgb", "").replace("(", "").replace(")", "") |
|
|
parts = [float(x.strip()) for x in c.split(",")] |
|
|
|
|
|
else: |
|
|
raise ValueError(f"Unsupported color format: {color}") |
|
|
|
|
|
|
|
|
if len(parts) == 3: |
|
|
parts.append(1.0) |
|
|
|
|
|
return ( |
|
|
int(round(parts[0])), |
|
|
int(round(parts[1])), |
|
|
int(round(parts[2])), |
|
|
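        # Alpha <= 1 is treated as a 0-1 fraction; larger values are assumed to already be in the 0-255 range.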
int(round(parts[3] * 255 if parts[3] <= 1 else parts[3])) |
|
|
) |
|
|
|
|
|
|
|
|
def is_not_dark(color, threshold=30): |
|
|
return not all(c <= threshold for c in color) |
|
|
|
|
|
|
|
|
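# Downsample to 50x50 and return the most common pixel that is not near-black; if every pixel is dark, fall back to all pixels.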
def get_dominant_color_exclude_dark(image_pil): |
|
|
img_small = image_pil.convert("RGB").resize((50, 50)) |
|
|
pixels = list(img_small.getdata()) |
|
|
filtered_pixels = [p for p in pixels if is_not_dark(p)] |
|
|
if not filtered_pixels: |
|
|
filtered_pixels = pixels |
|
|
most_common = Counter(filtered_pixels).most_common(1)[0][0] |
|
|
return most_common |
|
|
|
|
|
|
|
|
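# Fill the white (to-be-generated) region of the mask with a solid color: the image's dominant non-dark color, or an explicit target color if given.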
def replace_color_in_mask(image_pil, mask_pil, target_color=None): |
|
|
img = np.array(image_pil.convert("RGB")) |
|
|
mask = np.array(mask_pil.convert("L")) |
|
|
|
|
|
mask_white = mask == 255 |
|
|
mask_nonwhite = ~mask_white |
|
|
|
|
|
if target_color in [None, ""]: |
|
|
nonwhite_pixels = img[mask_nonwhite] |
|
|
nonwhite_img = Image.fromarray(nonwhite_pixels.reshape((-1, 1, 3))) |
|
|
target_color = get_dominant_color_exclude_dark(nonwhite_img) |
|
|
else: |
|
|
parsed = parse_color(target_color) |
|
|
target_color = parsed[:3] |
|
|
|
|
|
img[mask_white] = target_color |
|
|
return Image.fromarray(img) |
|
|
|
|
|
|
|
|
def expand_white_around_black(image: Image.Image, expand_ratio=0.1) -> Image.Image: |
|
|
""" |
|
|
Expand the white areas around the black region by a percentage of the black region size. |
|
|
|
|
|
Args: |
|
|
image: PIL grayscale image (mode "L"). |
|
|
expand_ratio: Fraction of black region size to expand white sides (default 0.1 = 10%). |
|
|
|
|
|
Returns: |
|
|
PIL Image with white expanded around black. |
|
|
""" |
|
|
arr = np.array(image) |
|
|
|
|
|
black_mask = arr == 0 |
|
|
|
|
|
height, width = arr.shape |
|
|
coords = np.argwhere(black_mask) |
|
|
|
|
|
if coords.size == 0: |
|
|
|
|
|
return image.copy() |
|
|
|
|
|
y_min, x_min = coords.min(axis=0) |
|
|
y_max, x_max = coords.max(axis=0) |
|
|
|
|
|
expand_x = int((x_max - x_min + 1) * expand_ratio) |
|
|
expand_y = int((y_max - y_min + 1) * expand_ratio) |
|
|
|
|
|
|
|
|
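    # Shrink the black box on a side only when everything beyond that side is pure white, so existing content is never cut into.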
if y_min > 0 and np.all(arr[:y_min, :] == 255): |
|
|
y_min = min(height - 1, y_min + expand_y) |
|
|
|
|
|
if y_max < height - 1 and np.all(arr[y_max + 1:, :] == 255): |
|
|
y_max = max(0, y_max - expand_y) |
|
|
|
|
|
if x_min > 0 and np.all(arr[:, :x_min] == 255): |
|
|
x_min = min(width - 1, x_min + expand_x) |
|
|
|
|
|
if x_max < width - 1 and np.all(arr[:, x_max + 1:] == 255): |
|
|
x_max = max(0, x_max - expand_x) |
|
|
|
|
|
|
|
|
expanded_arr = np.full_like(arr, 255) |
|
|
|
|
|
|
|
|
expanded_arr[y_min:y_max+1, x_min:x_max+1] = 0 |
|
|
|
|
|
return Image.fromarray(expanded_arr) |
|
|
|
|
|
|
|
|
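# Build the canvas page: strip the external-file references from index.html and inline canvas.py so the iframe's srcdoc is self-contained.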
def load_html(): |
|
|
body, canvaspy = "", "" |
|
|
with open("index.html", encoding="utf8") as f: |
|
|
body = f.read() |
|
|
with open("canvas.py", encoding="utf8") as f: |
|
|
canvaspy = f.read() |
|
|
body = body.replace("- paths:\n", "") |
|
|
body = body.replace(" - ./canvas.py\n", "") |
|
|
body = body.replace("from canvas import InfCanvas", canvaspy) |
|
|
return body |
|
|
|
|
|
|
|
|
def test(x): |
|
|
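    # x is a throwaway input from Gradio; the iframe srcdoc is always regenerated from load_html().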
x = load_html() |
|
|
return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; |
|
|
display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms |
|
|
allow-scripts allow-same-origin allow-popups |
|
|
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" |
|
|
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""" |
|
|
|
|
|
|
|
|
try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except AttributeError:  # Pillow < 9.1 has no Resampling enum
    SAMPLING_MODE = Image.LANCZOS
|
|
|
|
|
try:
    contain_func = ImageOps.contain
except AttributeError:  # older Pillow without ImageOps.contain
    def contain_func(image, size, method=SAMPLING_MODE):
|
|
|
|
|
im_ratio = image.width / image.height |
|
|
dest_ratio = size[0] / size[1] |
|
|
if im_ratio != dest_ratio: |
|
|
if im_ratio > dest_ratio: |
|
|
new_height = int(image.height / image.width * size[0]) |
|
|
if new_height != size[1]: |
|
|
size = (size[0], new_height) |
|
|
else: |
|
|
new_width = int(image.width / image.height * size[1]) |
|
|
if new_width != size[0]: |
|
|
size = (new_width, size[1]) |
|
|
return image.resize(size, resample=method) |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = parser.parse_args() |
|
|
else: |
|
|
args = parser.parse_args(["--debug"]) |
|
|
|
|
|
if args.auth is not None: |
|
|
args.auth = tuple(args.auth) |
|
|
|
|
|
model = {} |
|
|
|
|
|
|
|
|
def get_token(): |
|
|
token = "" |
|
|
if os.path.exists(".token"): |
|
|
with open(".token", "r") as f: |
|
|
            token = f.read().strip()  # strip a trailing newline that would break authentication
|
|
token = os.environ.get("hftoken", token) |
|
|
return token |
|
|
|
|
|
|
|
|
def save_token(token): |
|
|
with open(".token", "w") as f: |
|
|
f.write(token) |
|
|
|
|
|
|
|
|
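# Heuristic upscaling for small selections: keep anything already >= 512 px (or with a long side >= 608), otherwise scale by a factor based on the shorter side and snap to multiples of 8.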
def my_resize(width, height): |
|
|
if width >= 512 and height >= 512: |
|
|
return width, height |
|
|
if width == height: |
|
|
return 512, 512 |
|
|
smaller = min(width, height) |
|
|
larger = max(width, height) |
|
|
if larger >= 608: |
|
|
return width, height |
|
|
factor = 1 |
|
|
if smaller < 290: |
|
|
factor = 2 |
|
|
elif smaller < 330: |
|
|
factor = 1.75 |
|
|
elif smaller < 384: |
|
|
factor = 1.375 |
|
|
elif smaller < 400: |
|
|
factor = 1.25 |
|
|
elif smaller < 450: |
|
|
factor = 1.125 |
|
|
return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8 |
|
|
|
|
|
|
|
|
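# Textual-inversion helper: registers a learned token embedding with a CLIP tokenizer/text-encoder pair.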
def load_learned_embed_in_clip( |
|
|
learned_embeds_path, text_encoder, tokenizer, token=None |
|
|
): |
|
|
|
|
|
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") |
|
|
|
|
|
|
|
|
trained_token = list(loaded_learned_embeds.keys())[0] |
|
|
embeds = loaded_learned_embeds[trained_token] |
|
|
|
|
|
|
|
|
    # Cast to the text encoder's dtype; Tensor.to() is not in-place, so reassign.
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)
|
|
|
|
|
|
|
|
token = token if token is not None else trained_token |
|
|
num_added_tokens = tokenizer.add_tokens(token) |
|
|
if num_added_tokens == 0: |
|
|
raise ValueError( |
|
|
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer." |
|
|
) |
|
|
|
|
|
|
|
|
text_encoder.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
|
|
|
|
token_id = tokenizer.convert_tokens_to_ids(token) |
|
|
text_encoder.get_input_embeddings().weight.data[token_id] = embeds |
|
|
|
|
|
|
|
|
MODEL_NAME = args.stablepy_model |
|
|
print(f"Loading model {MODEL_NAME}. This may take some time if it is a Diffusers-format model.") |
|
|
|
|
|
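# Options shared by every load_pipe call; retain_task_model_in_cache keeps task pipelines cached so switching between repaint/img2img/inpaint/txt2img avoids a full reload.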
LOAD_PIPE_ARGS = dict( |
|
|
vae_model=None, |
|
|
retain_task_model_in_cache=True, |
|
|
controlnet_model="Automatic", |
|
|
type_model_precision=torch.float16, |
|
|
) |
|
|
|
|
|
disable_progress_bars() |
|
|
base_model = Model_Diffusers( |
|
|
base_model_id=MODEL_NAME, |
|
|
task_name="repaint", |
|
|
device=START_DEVICE_STABLEPY, |
|
|
**LOAD_PIPE_ARGS, |
|
|
) |
|
|
enable_progress_bars() |
|
|
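# If the pipe was built on CPU (ZeroGPU), move it to the GPU in half precision now so the @spaces.GPU-decorated calls find it ready.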
if START_DEVICE_STABLEPY: |
|
|
base_model.device = torch.device("cuda:0") |
|
|
base_model.pipe.to(torch.device("cuda:0"), torch.float16) |
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusion: |
|
|
def __init__( |
|
|
self, |
|
|
token: str = "", |
|
|
model_name: str = "stable-diffusion-v1-5/stable-diffusion-v1-5", |
|
|
model_path: str = None, |
|
|
inpainting_model: bool = False, |
|
|
**kwargs, |
|
|
): |
|
|
if DEBUG_MODE: |
|
|
print("sd task selection") |
|
|
|
|
|
def run( |
|
|
self, |
|
|
image_pil, |
|
|
prompt="", |
|
|
negative_prompt="", |
|
|
guidance_scale=7.5, |
|
|
resize_check=True, |
|
|
enable_safety=True, |
|
|
fill_mode="patchmatch", |
|
|
strength=0.75, |
|
|
step=50, |
|
|
enable_img2img=False, |
|
|
use_seed=False, |
|
|
seed_val=-1, |
|
|
generate_num=1, |
|
|
scheduler="", |
|
|
scheduler_eta=0.0, |
|
|
controlnet_union=True, |
|
|
expand_mask_percent=0.1, |
|
|
color_selector_=None, |
|
|
scheduler_type="Automatic", |
|
|
prompt_weight="Classic", |
|
|
image_resolution=1024, |
|
|
img_height=1024, |
|
|
img_width=1024, |
|
|
loraA=None, |
|
|
loraAscale=1., |
|
|
**kwargs, |
|
|
): |
|
|
global base_model |
|
|
|
|
|
width, height = image_pil.size |
|
|
|
|
|
if DEBUG_MODE: |
|
|
image_pil.save( |
|
|
f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" |
|
|
) |
|
|
print(image_pil.size) |
|
|
|
|
|
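        # The canvas sends an RGBA selection: RGB is the image, alpha marks where content already exists (255) versus empty canvas (0).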
sel_buffer = np.array(image_pil) |
|
|
img = sel_buffer[:, :, 0:3] |
|
|
mask = sel_buffer[:, :, -1] |
|
|
nmask = 255 - mask |
|
|
process_width = width |
|
|
process_height = height |
|
|
|
|
|
extra_kwargs = { |
|
|
"num_steps": step, |
|
|
"guidance_scale": guidance_scale, |
|
|
"sampler": scheduler, |
|
|
"num_images": generate_num, |
|
|
"negative_prompt": negative_prompt, |
|
|
"seed": (seed_val if use_seed else -1), |
|
|
"strength": strength, |
|
|
"schedule_type": scheduler_type, |
|
|
"syntax_weights": prompt_weight, |
|
|
"lora_A": (loraA if loraA != "None" else None), |
|
|
"lora_scale_A": loraAscale, |
|
|
} |
|
|
|
|
|
if resize_check: |
|
|
process_width, process_height = my_resize(width, height) |
|
|
extra_kwargs["image_resolution"] = 1024 |
|
|
else: |
|
|
extra_kwargs["image_resolution"] = image_resolution |
|
|
|
|
|
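        # The selection is fully covered by existing content and img2img is enabled: run plain img2img over the crop.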
if nmask.sum() < 1 and enable_img2img: |
|
|
|
|
|
init_image = Image.fromarray(img) |
|
|
base_model.load_pipe( |
|
|
base_model_id=MODEL_NAME, |
|
|
task_name="img2img", |
|
|
**LOAD_PIPE_ARGS, |
|
|
) |
|
|
images = base_model( |
|
|
prompt=prompt, |
|
|
image=init_image.resize( |
|
|
(process_width, process_height), resample=SAMPLING_MODE |
|
|
), |
|
|
strength=strength, |
|
|
**extra_kwargs, |
|
|
)[0] |
|
|
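        # Mixed selection (existing content + empty area): outpaint. g_diffuser and the *_color modes expect an inverted mask; other fill modes prefill the empty region themselves.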
elif mask.sum() > 0: |
|
|
if fill_mode == "g_diffuser" or "_color" in fill_mode: |
|
|
mask = 255 - mask |
|
|
mask = mask[:, :, np.newaxis].repeat(3, axis=2) |
|
|
if "_color" not in fill_mode: |
|
|
img, mask = functbl[fill_mode](img, mask) |
|
|
|
|
|
|
|
|
else: |
|
|
img, mask = functbl[fill_mode](img, mask) |
|
|
mask = 255 - mask |
|
|
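                # Coarsen the mask to 8x8 blocks (max-pool, then re-expand) so it aligns with the VAE's 8x latent downsampling.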
mask = skimage.measure.block_reduce(mask, (8, 8), np.max) |
|
|
mask = mask.repeat(8, axis=0).repeat(8, axis=1) |
|
|
|
|
|
init_image = Image.fromarray(img) |
|
|
mask_image = Image.fromarray(mask) |
|
|
|
|
|
input_image = init_image.resize( |
|
|
(process_width, process_height), resample=SAMPLING_MODE |
|
|
) |
|
|
|
|
|
if DEBUG_MODE: |
|
|
init_image.save( |
|
|
f"init_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" |
|
|
) |
|
|
print(init_image.size) |
|
|
mask_image.save( |
|
|
f"mask_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" |
|
|
) |
|
|
print(mask_image.size) |
|
|
|
|
|
if fill_mode == "pad_common_color": |
|
|
init_image = replace_color_in_mask(init_image, mask_image, None) |
|
|
elif fill_mode == "pad_selected_color": |
|
|
init_image = replace_color_in_mask(init_image, mask_image, color_selector_) |
|
|
|
|
|
            if expand_mask_percent:
                if mask_image.mode != "L":
                    if DEBUG_MODE:
                        print("convert to L")
                    mask_image = mask_image.convert("L")
                mask_image = expand_white_around_black(mask_image, expand_ratio=expand_mask_percent)
                if DEBUG_MODE:
                    mask_image.save(
                        f"mask_image_expanded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                    )
                    print(mask_image.size)
|
|
if DEBUG_MODE: |
|
|
print(mask_image.size) |
|
|
|
|
|
if controlnet_union: |
|
|
|
|
|
base_model.load_pipe( |
|
|
base_model_id=MODEL_NAME, |
|
|
task_name="repaint", |
|
|
**LOAD_PIPE_ARGS, |
|
|
) |
|
|
images = base_model( |
|
|
prompt=prompt, |
|
|
image=input_image, |
|
|
img_width=process_width, |
|
|
img_height=process_height, |
|
|
image_mask=mask_image.resize((process_width, process_height)), |
|
|
**extra_kwargs, |
|
|
)[0] |
|
|
else: |
|
|
|
|
|
base_model.load_pipe( |
|
|
base_model_id=MODEL_NAME, |
|
|
task_name="inpaint", |
|
|
**LOAD_PIPE_ARGS, |
|
|
) |
|
|
images = base_model( |
|
|
prompt=prompt, |
|
|
image=input_image, |
|
|
image_mask=mask_image.resize((process_width, process_height)), |
|
|
**extra_kwargs, |
|
|
)[0] |
|
|
else: |
|
|
|
|
|
base_model.load_pipe( |
|
|
base_model_id=MODEL_NAME, |
|
|
task_name="txt2img", |
|
|
**LOAD_PIPE_ARGS, |
|
|
) |
|
|
|
|
|
images = base_model( |
|
|
prompt=prompt, |
|
|
img_height=img_height, |
|
|
img_width=img_width, |
|
|
**extra_kwargs, |
|
|
)[0] |
|
|
|
|
|
if DEBUG_MODE: |
|
|
print(f"TASK NAME {base_model.task_name}") |
|
|
return images |
|
|
|
|
|
|
|
|
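# ZeroGPU entry point: a GPU is attached only while this function runs (up to the requested 15 s per call).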
@spaces.GPU(duration=15) |
|
|
def generate_images( |
|
|
cur_model, |
|
|
pil, |
|
|
prompt_text, |
|
|
negative_prompt_text, |
|
|
guidance, |
|
|
strength, |
|
|
step, |
|
|
resize_check, |
|
|
fill_mode, |
|
|
enable_safety, |
|
|
use_seed, |
|
|
seed_val, |
|
|
generate_num, |
|
|
scheduler, |
|
|
scheduler_eta, |
|
|
enable_img2img, |
|
|
width, |
|
|
height, |
|
|
controlnet_union, |
|
|
expand_mask, |
|
|
color_selector_, |
|
|
scheduler_type, |
|
|
prompt_weight, |
|
|
image_resolution, |
|
|
img_height, |
|
|
img_width, |
|
|
loraA, |
|
|
loraAscale, |
|
|
): |
|
|
|
|
|
return cur_model.run( |
|
|
image_pil=pil, |
|
|
prompt=prompt_text, |
|
|
negative_prompt=negative_prompt_text, |
|
|
guidance_scale=guidance, |
|
|
strength=strength, |
|
|
step=step, |
|
|
resize_check=resize_check, |
|
|
fill_mode=fill_mode, |
|
|
enable_safety=enable_safety, |
|
|
use_seed=use_seed, |
|
|
seed_val=seed_val, |
|
|
generate_num=generate_num, |
|
|
scheduler=scheduler, |
|
|
scheduler_eta=scheduler_eta, |
|
|
enable_img2img=enable_img2img, |
|
|
width=width, |
|
|
height=height, |
|
|
controlnet_union=controlnet_union, |
|
|
expand_mask_percent=expand_mask, |
|
|
color_selector_=color_selector_, |
|
|
scheduler_type=scheduler_type, |
|
|
prompt_weight=prompt_weight, |
|
|
image_resolution=image_resolution, |
|
|
img_height=img_height, |
|
|
img_width=img_width, |
|
|
loraA=loraA, |
|
|
loraAscale=loraAscale, |
|
|
) |
|
|
|
|
|
|
|
|
def run_outpaint( |
|
|
sel_buffer_str, |
|
|
prompt_text, |
|
|
negative_prompt_text, |
|
|
strength, |
|
|
guidance, |
|
|
step, |
|
|
resize_check, |
|
|
fill_mode, |
|
|
enable_safety, |
|
|
use_correction, |
|
|
enable_img2img, |
|
|
use_seed, |
|
|
seed_val, |
|
|
generate_num, |
|
|
scheduler, |
|
|
scheduler_eta, |
|
|
controlnet_union, |
|
|
expand_mask, |
|
|
color_selector_, |
|
|
scheduler_type, |
|
|
prompt_weight, |
|
|
image_resolution, |
|
|
img_height, |
|
|
img_width, |
|
|
loraA, |
|
|
loraAscale, |
|
|
interrogate_mode, |
|
|
state, |
|
|
): |
|
|
|
|
|
if DEBUG_MODE: |
|
|
print("start proceed") |
|
|
data = base64.b64decode(str(sel_buffer_str)) |
|
|
|
|
|
pil = Image.open(io.BytesIO(data)) |
|
|
if interrogate_mode: |
|
|
if "interrogator" not in model: |
|
|
model["interrogator"] = Interrogator() |
|
|
interrogator = model["interrogator"] |
|
|
img = np.array(pil)[:, :, 0:3] |
|
|
mask = np.array(pil)[:, :, -1] |
|
|
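        # np.nonzero returns (row_indices, col_indices); crop to the bounding box of the visible area before captioning.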
x, y = np.nonzero(mask) |
|
|
if len(x) > 0: |
|
|
x0, x1 = x.min(), x.max() + 1 |
|
|
y0, y1 = y.min(), y.max() + 1 |
|
|
img = img[x0:x1, y0:y1, :] |
|
|
pil = Image.fromarray(img) |
|
|
interrogate_ret = interrogator.interrogate(pil) |
|
|
return ( |
|
|
gr.update(value=",".join([sel_buffer_str]),), |
|
|
gr.update(label="Prompt", value=interrogate_ret), |
|
|
state, |
|
|
) |
|
|
width, height = pil.size |
|
|
sel_buffer = np.array(pil) |
|
|
cur_model = StableDiffusion() |
|
|
if DEBUG_MODE: |
|
|
print("start inference") |
|
|
|
|
|
images = generate_images( |
|
|
cur_model, |
|
|
pil, |
|
|
prompt_text, |
|
|
negative_prompt_text, |
|
|
guidance, |
|
|
strength, |
|
|
step, |
|
|
resize_check, |
|
|
fill_mode, |
|
|
enable_safety, |
|
|
use_seed, |
|
|
seed_val, |
|
|
generate_num, |
|
|
scheduler, |
|
|
scheduler_eta, |
|
|
enable_img2img, |
|
|
width, |
|
|
height, |
|
|
controlnet_union, |
|
|
expand_mask, |
|
|
color_selector_, |
|
|
scheduler_type, |
|
|
prompt_weight, |
|
|
image_resolution, |
|
|
img_height, |
|
|
img_width, |
|
|
loraA, |
|
|
loraAscale, |
|
|
) |
|
|
|
|
|
if DEBUG_MODE: |
|
|
print("return result") |
|
|
base64_str_lst = [] |
|
|
if enable_img2img: |
|
|
use_correction = "border_mode" |
|
|
for image in images: |
|
|
image = correction_func.run(pil.resize(image.size), image, mode=use_correction) |
|
|
resized_img = image.resize((width, height), resample=SAMPLING_MODE,) |
|
|
out = sel_buffer.copy() |
|
|
out[:, :, 0:3] = np.array(resized_img) |
|
|
out[:, :, -1] = 255 |
|
|
out_pil = Image.fromarray(out) |
|
|
out_buffer = io.BytesIO() |
|
|
out_pil.save(out_buffer, format="PNG") |
|
|
out_buffer.seek(0) |
|
|
base64_bytes = base64.b64encode(out_buffer.read()) |
|
|
base64_str = base64_bytes.decode("ascii") |
|
|
base64_str_lst.append(base64_str) |
|
|
return ( |
|
|
gr.update(label=str(state + 1), value=",".join(base64_str_lst),), |
|
|
gr.update(label="Prompt"), |
|
|
state + 1, |
|
|
) |
|
|
|
|
|
|
|
|
generate_images.zerogpu = True |
|
|
run_outpaint.zerogpu = True |
|
|
|
|
|
|
|
|
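# For canvas actions, return inline JS that clicks the matching button inside the iframe; everything else is read from ./js/<name>.js.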
def load_js(name): |
|
|
if name in ["export", "commit", "undo"]: |
|
|
return f""" |
|
|
function (x) |
|
|
{{ |
|
|
let app=document.querySelector("gradio-app"); |
|
|
app=app.shadowRoot??app; |
|
|
let frame=app.querySelector("#sdinfframe").contentWindow.document; |
|
|
let button=frame.querySelector("#{name}"); |
|
|
button.click(); |
|
|
return x; |
|
|
}} |
|
|
""" |
|
|
ret = "" |
|
|
with open(f"./js/{name}.js", "r") as f: |
|
|
ret = f.read() |
|
|
return ret |
|
|
|
|
|
|
|
|
proceed_button_js = load_js("proceed") |
|
|
setup_button_js = load_js("setup") |
|
|
|
|
|
|
|
|
blocks = gr.Blocks( |
|
|
title="StableDiffusion-Infinity", |
|
|
css=""" |
|
|
.tabs { |
|
|
margin-top: 0rem; |
|
|
margin-bottom: 0rem; |
|
|
} |
|
|
#markdown { |
|
|
min-height: 0rem; |
|
|
} |
|
|
""", |
|
|
) |
|
|
model_path_input_val = "" |
|
|
with blocks as demo: |
|
|
|
|
|
title = gr.Markdown( |
|
|
""" |
|
|
This is a modified demo of [stablediffusion-infinity](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) with SDXL support. |
|
|
|
|
|
**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity) |
|
|
""", |
|
|
elem_id="markdown", |
|
|
) |
|
|
|
|
|
frame = gr.HTML(test(2), visible=True) |
|
|
|
|
|
if not AUTO_SETUP: |
|
|
model_choices_lst = [""] |
|
|
if args.local_model: |
|
|
model_path_input_val = args.local_model |
|
|
|
|
|
elif args.remote_model: |
|
|
model_path_input_val = args.remote_model |
|
|
|
|
|
with gr.Row(elem_id="setup_row"): |
|
|
with gr.Column(scale=4, min_width=350): |
|
|
token = gr.Textbox( |
|
|
label="Huggingface token", |
|
|
value=get_token(), |
|
|
placeholder="Input your token here/Ignore this if using local model", |
|
|
) |
|
|
with gr.Column(scale=3, min_width=320): |
|
|
model_selection = gr.Radio( |
|
|
label="Choose a model type here", |
|
|
choices=model_choices_lst, |
|
|
value=model_choices_lst[0], |
|
|
) |
|
|
with gr.Column(scale=1, min_width=100): |
|
|
canvas_width = gr.Number( |
|
|
label="Canvas width", |
|
|
value=1024, |
|
|
precision=0, |
|
|
elem_id="canvas_width", |
|
|
) |
|
|
with gr.Column(scale=1, min_width=100): |
|
|
canvas_height = gr.Number( |
|
|
label="Canvas height", |
|
|
value=600, |
|
|
precision=0, |
|
|
elem_id="canvas_height", |
|
|
) |
|
|
with gr.Column(scale=1, min_width=100): |
|
|
selection_size = gr.Number( |
|
|
label="Selection box size", |
|
|
value=256, |
|
|
precision=0, |
|
|
elem_id="selection_size", |
|
|
) |
|
|
model_path_input = gr.Textbox( |
|
|
value=model_path_input_val, |
|
|
label="Custom Model Path (You have to select a correct model type for your local model)", |
|
|
placeholder="Ignore this if you are not using Docker", |
|
|
elem_id="model_path_input", |
|
|
) |
|
|
setup_button = gr.Button("Click to Setup (may take a while)", variant="primary") |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=3, min_width=270): |
|
|
init_mode = gr.Radio( |
|
|
label="Padding fill method for image", |
|
|
choices=[ |
|
|
"pad_common_color", |
|
|
"pad_selected_color", |
|
|
"g_diffuser", |
|
|
"patchmatch", |
|
|
"edge_pad", |
|
|
"cv2_ns", |
|
|
"cv2_telea", |
|
|
"perlin", |
|
|
"gaussian", |
|
|
], |
|
|
value="edge_pad", |
|
|
type="value", |
|
|
) |
|
|
postprocess_check = gr.Radio( |
|
|
label="Lighting and color adjustment mode", |
|
|
choices=["disabled", "mask_mode", "border_mode",], |
|
|
value="disabled", |
|
|
type="value", |
|
|
) |
|
|
            expand_mask_gui = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="Mask expansion (%)", info="How far the mask extends from the edges of the image; only applies when pad_selected_color is selected. ⚠️ When merging two images into one via outpainting, set this to 0 to avoid unexpected results.")
|
|
            color_selector = gr.ColorPicker(value="#FFFFFF", label="Color for `pad_selected_color`", info="Choose the color used to fill the extended padding area.")
|
|
|
|
|
with gr.Column(scale=3, min_width=270): |
|
|
sd_prompt = gr.Textbox( |
|
|
label="Prompt", placeholder="input your prompt here!", lines=4 |
|
|
) |
|
|
sd_negative_prompt = gr.Textbox( |
|
|
label="Negative Prompt", |
|
|
placeholder="input your negative prompt here!", |
|
|
lines=4, |
|
|
) |
|
|
with gr.Column(scale=2, min_width=150): |
|
|
with gr.Group(): |
|
|
with gr.Row(): |
|
|
sd_strength = gr.Slider( |
|
|
label="Strength", |
|
|
minimum=0.0, |
|
|
maximum=1.0, |
|
|
value=1.0, |
|
|
step=0.01, |
|
|
) |
|
|
with gr.Row(): |
|
|
sd_scheduler = gr.Dropdown( |
|
|
scheduler_names, |
|
|
value="TCD", |
|
|
label="Sampler", |
|
|
) |
|
|
sd_scheduler_type = gr.Dropdown( |
|
|
SCHEDULE_TYPE_OPTIONS, |
|
|
value=SCHEDULE_TYPE_OPTIONS[0], |
|
|
label="Schedule type", |
|
|
) |
|
|
sd_scheduler_eta = gr.Number(label="Eta", value=0.0, visible=False) |
|
|
sd_controlnet_union = gr.Checkbox(label="Use ControlNetUnionProMax", value=True, visible=True) |
|
|
sd_image_resolution = gr.Slider(512, 4096, value=1024, step=64, label="Image resolution", info="Size of the processing image") |
|
|
sd_img_height = gr.Slider(512, 4096, value=1024, step=64, label="Height for txt2img", info="Used if no image is in the selected canvas area.", visible=False) |
|
|
sd_img_width = gr.Slider(512, 4096, value=1024, step=64, label="Width for txt2img", info="Used if no image is in the selected canvas area.", visible=False) |
|
|
|
|
|
with gr.Column(scale=1, min_width=80): |
|
|
sd_generate_num = gr.Number(label="Sample number", minimum=1, maximum=10, value=1) |
|
|
sd_step = gr.Number(label="Step", value=12, minimum=2) |
|
|
sd_guidance = gr.Number(label="Guidance scale", value=1.5, step=0.5) |
|
|
sd_prompt_weight = gr.Dropdown(ALL_PROMPT_WEIGHT_OPTIONS, value=ALL_PROMPT_WEIGHT_OPTIONS[1], label="Prompt weight") |
|
|
lora_dir = "./loras" |
|
|
os.makedirs(lora_dir, exist_ok=True) |
|
|
lora_files = [ |
|
|
f for f in os.listdir(lora_dir) |
|
|
if os.path.isfile(os.path.join(lora_dir, f)) |
|
|
] |
|
|
lora_files.insert(0, "None") |
|
|
sd_loraA = gr.Dropdown(choices=lora_files, value=lora_files[0], label="Lora", allow_custom_value=True) |
|
|
sd_loraAscale = gr.Slider(-2., 2., value=1., step=0.01, label="Lora scale") |
|
|
|
|
|
proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE) |
|
|
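    # The intentionally invalid img src fires onerror, which injects the helper scripts (xss.js, keyboard.js) into the page on load.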
xss_js = load_js("xss").replace("\n", " ") |
|
|
xss_html = gr.HTML( |
|
|
value=f""" |
|
|
<img src='hts://not.exist' onerror='{xss_js}'>""", |
|
|
visible=False, |
|
|
) |
|
|
xss_keyboard_js = load_js("keyboard").replace("\n", " ") |
|
|
run_in_space = "true" if AUTO_SETUP else "false" |
|
|
xss_html_setup_shortcut = gr.HTML( |
|
|
value=f""" |
|
|
<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""", |
|
|
visible=False, |
|
|
) |
|
|
|
|
|
sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False) |
|
|
sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False) |
|
|
safety_check = gr.Checkbox(label="Safety checker", value=True, visible=False) |
|
|
interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False) |
|
|
upload_button = gr.Button( |
|
|
"Before uploading the image you need to setup the canvas first", visible=False |
|
|
) |
|
|
sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False) |
|
|
sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False) |
|
|
model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0") |
|
|
model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input") |
|
|
upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0") |
|
|
model_output_state = gr.State(value=0) |
|
|
upload_output_state = gr.State(value=0) |
|
|
cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False) |
|
|
if not AUTO_SETUP: |
|
|
|
|
|
def setup_func(token_val, width, height, size, model_choice, model_path): |
|
|
try: |
|
|
StableDiffusion() |
|
|
except Exception as e: |
|
|
print(e) |
|
|
return {token: gr.update(value=str(e))} |
|
|
|
|
|
init_val = "patchmatch" |
|
|
return { |
|
|
token: gr.update(visible=False), |
|
|
canvas_width: gr.update(visible=False), |
|
|
canvas_height: gr.update(visible=False), |
|
|
selection_size: gr.update(visible=False), |
|
|
setup_button: gr.update(visible=False), |
|
|
frame: gr.update(visible=True), |
|
|
upload_button: gr.update(value="Upload Image"), |
|
|
model_selection: gr.update(visible=False), |
|
|
model_path_input: gr.update(visible=False), |
|
|
init_mode: gr.update(value=init_val), |
|
|
} |
|
|
|
|
|
setup_button.click( |
|
|
fn=setup_func, |
|
|
inputs=[ |
|
|
token, |
|
|
canvas_width, |
|
|
canvas_height, |
|
|
selection_size, |
|
|
model_selection, |
|
|
model_path_input, |
|
|
], |
|
|
outputs=[ |
|
|
token, |
|
|
canvas_width, |
|
|
canvas_height, |
|
|
selection_size, |
|
|
setup_button, |
|
|
frame, |
|
|
upload_button, |
|
|
model_selection, |
|
|
model_path_input, |
|
|
init_mode, |
|
|
], |
|
|
js=setup_button_js, |
|
|
) |
|
|
|
|
|
proceed_event = proceed_button.click( |
|
|
fn=run_outpaint, |
|
|
inputs=[ |
|
|
model_input, |
|
|
sd_prompt, |
|
|
sd_negative_prompt, |
|
|
sd_strength, |
|
|
sd_guidance, |
|
|
sd_step, |
|
|
sd_resize, |
|
|
init_mode, |
|
|
safety_check, |
|
|
postprocess_check, |
|
|
sd_img2img, |
|
|
sd_use_seed, |
|
|
sd_seed_val, |
|
|
sd_generate_num, |
|
|
sd_scheduler, |
|
|
sd_scheduler_eta, |
|
|
sd_controlnet_union, |
|
|
expand_mask_gui, |
|
|
color_selector, |
|
|
sd_scheduler_type, |
|
|
sd_prompt_weight, |
|
|
sd_image_resolution, |
|
|
sd_img_height, |
|
|
sd_img_width, |
|
|
sd_loraA, |
|
|
sd_loraAscale, |
|
|
interrogate_check, |
|
|
model_output_state, |
|
|
], |
|
|
outputs=[model_output, sd_prompt, model_output_state], |
|
|
js=proceed_button_js, |
|
|
) |
|
|
|
|
|
    if tuple(map(int, gr.__version__.split("."))) >= (3, 6):
|
|
cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event]) |
|
|
|
|
|
|
|
|
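# Forward the remaining CLI args to launch(); model-selection flags are stripped first because gr.Blocks.launch() does not accept them.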
launch_extra_kwargs = { |
|
|
"show_error": True, |
|
|
|
|
|
} |
|
|
launch_kwargs = vars(args) |
|
|
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None} |
|
|
launch_kwargs.pop("remote_model", None) |
|
|
launch_kwargs.pop("local_model", None) |
|
|
launch_kwargs.pop("fp32", None) |
|
|
launch_kwargs.pop("lowvram", None) |
|
|
launch_kwargs.pop("stablepy_model", None) |
|
|
launch_kwargs.update(launch_extra_kwargs) |
|
|
try:
    import google.colab  # noqa: F401

    # Running inside Google Colab: enable debug output and a public share link.
    launch_kwargs["debug"] = True
    launch_kwargs["share"] = True
except ImportError:
    launch_kwargs["share"] = False
launch_kwargs.pop("encrypt", None)  # argparse-only flag; gr.Blocks.launch() does not accept it
|
|
|
|
|
if not launch_kwargs["share"]: |
|
|
demo.launch() |
|
|
else: |
|
|
launch_kwargs["server_name"] = "0.0.0.0" |
|
|
demo.queue().launch(**launch_kwargs) |
|
|
|