File size: 5,933 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import torch
import gradio as gr
import modules.scripts as scripts
import modules.shared as shared
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
from modules.prompt_parser import SdConditioning
class Script(scripts.Script):
    """Always-visible script that rescales the influence of the negative prompt.

    On every CFG denoiser call the negative (uncond) conditioning is replaced
    by ``torch.lerp(empty, uncond, weight)``, where ``empty`` is the
    conditioning of an empty negative prompt. A weight of 0 therefore yields
    the empty prompt, 1 leaves the negative prompt unchanged, and values above
    1 (the slider goes up to 2) extrapolate beyond it.
    """

    def title(self):
        # NOTE: the misspelling is kept deliberately -- the title is a runtime
        # string that other code may use to identify this script.
        return "Negative Prompt Weight Extention"

    def show(self, is_img2img):
        # Always visible, on both tabs.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the weight accordion: a slider, a numeric box and a reset button.

        Returns:
            list: the single component whose value presumably arrives in
            ``process`` as ``weight``.
        """
        with gr.Accordion("Negative Prompt Weight", open=True, elem_id="npw"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=100):
                    weight_input_slider = gr.Slider(minimum=0.00, maximum=2.00, step=.05, value=1.00, label="Weight", interactive=True, elem_id="npw-slider")
                with gr.Column(scale=1, min_width=120):
                    with gr.Row():
                        weight_input = gr.Number(value=1.00, precision=4, label="Negative Prompt Weight", show_label=False, elem_id="npw-number")
                        reset_but = gr.Button(value='✕', elem_id='npw-x', size='sm')
            # Client-side only: copy the number into the slider and outline the
            # visible tab's reset button, more strongly the further v is from 1.
            js = """(v) => {
['#tab_txt2img #npw-x', '#tab_img2img #npw-x'].forEach((selector, index) => {
const element = document.querySelector(selector);
if (document.querySelector(`#tab_${index ? 'img2img' : 'txt2img'}`).style.display === 'block') {
element.style.cssText += `outline:4px solid rgba(255,186,0,${Math.sqrt(Math.abs(v-1))}); border-radius: 0.4em !important;`;
}
});
return v;
}"""
            # Keep slider and number box in sync; the reset button sets both to 1.
            weight_input.change(None, [weight_input], weight_input_slider, _js=js)
            weight_input_slider.change(None, weight_input_slider, weight_input, _js="(x) => x")
            reset_but.click(None, [], [weight_input,weight_input_slider], _js="(x) => [1,1]")
        # Presumably lets the "NPW_weight" infotext key (written in process)
        # be pasted back into the number box.
        self.infotext_fields = [(weight_input, "NPW_weight")]
        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]
        return [weight_input]

    def process(self, p, weight):
        """Prepare the negative-prompt reweighting for one generation.

        Args:
            p: the processing object. An ``NPW_weight`` attribute on it (as
               set by the XYZ-plot axis registered in ``xyz_support``)
               overrides the UI value.
            weight: value of the UI number box. ``None`` (e.g. an empty box)
               is treated as 1.0, i.e. no reweighting, instead of crashing in
               ``print_warning``.
        """
        weight = getattr(p, 'NPW_weight', weight)
        if weight is None:
            weight = 1.0
        self.print_warning(weight)  # prints nothing when weight == 1
        self.width = p.width
        self.height = p.height
        self.weight = weight
        self.empty_uncond = None

        # Always drop callbacks registered from this file, not only ones this
        # instance remembers: a run that failed before postprocess (possibly in
        # another instance of this script) may have left one registered.
        self._remove_callbacks()

        if self.weight != 1.0:
            self.empty_uncond = self.make_empty_uncond(self.width, self.height)
            on_cfg_denoiser(self.denoiser_callback)
            self.callbacks_added = True

        p.extra_generation_params.update({
            "NPW_weight": self.weight,
        })

    def postprocess(self, p, processed, *args):
        """Unregister the denoiser callback once the generation has finished."""
        if hasattr(self, 'callbacks_added'):
            self._remove_callbacks()

    def _remove_callbacks(self):
        """Unregister this script file's callbacks and clear the bookkeeping flag."""
        remove_current_script_callbacks()
        if hasattr(self, 'callbacks_added'):
            delattr(self, 'callbacks_added')

    @staticmethod
    def _concat_and_lerp(empty, tensor, weight):
        """Return ``torch.lerp(empty, tensor, weight)`` with *empty* matched to *tensor*.

        *empty* is expanded along dim 0 to the batch size of *tensor*. When
        *tensor* is longer along dim 1 (presumably a prompt spanning several
        token chunks), *empty* is repeated along dim 1; if exactly one
        position is still missing, the first position of *tensor* is prepended.
        """
        if empty.shape[0] != tensor.shape[0]:
            empty = empty.expand(tensor.shape[0], *empty.shape[1:])
        if tensor.shape[1] > empty.shape[1]:
            repeats = tensor.shape[1] // empty.shape[1]
            empty = torch.cat([empty] * repeats, dim=1)
            if tensor.shape[1] == empty.shape[1] + 1:
                # Original author's guess: an extra leading position added by
                # ControlNet(?) -- reuse it from the real conditioning.
                empty = torch.cat([tensor[:, :1, :], empty], dim=1)
        return torch.lerp(empty, tensor, weight)

    def denoiser_callback(self, params):
        """Replace ``params.text_uncond`` with the reweighted negative conditioning.

        Registered via ``on_cfg_denoiser`` only while ``self.weight != 1``.
        ``params.text_uncond`` is either a tensor or a dict holding
        ``'vector'`` and ``'crossattn'`` tensors; both forms are handled.
        """
        uncond = params.text_uncond
        is_dict = isinstance(uncond, dict)

        # Re-encode the empty prompt only when the cache is missing or of the
        # other kind (dict vs. tensor). An exact type() comparison would also
        # treat a dict subclass as a mismatch and re-encode on every call.
        if self.empty_uncond is None or isinstance(self.empty_uncond, dict) != is_dict:
            self.empty_uncond = self.make_empty_uncond(self.width, self.height)
        empty_uncond = self.empty_uncond

        if is_dict:
            # Read both entries before writing: uncond is params.text_uncond.
            vector, cross = uncond['vector'], uncond['crossattn']
            params.text_uncond['vector'] = self._concat_and_lerp(empty_uncond['vector'], vector, self.weight)
            params.text_uncond['crossattn'] = self._concat_and_lerp(empty_uncond['crossattn'], cross, self.weight)
        else:
            params.text_uncond = self._concat_and_lerp(empty_uncond, uncond, self.weight)

    def make_empty_uncond(self, w, h):
        """Encode an empty negative prompt for an image of size ``w`` x ``h``."""
        prompt = SdConditioning([""], is_negative_prompt=True, width=w, height=h)
        return shared.sd_model.get_learned_conditioning(prompt)

    def print_warning(self, value):
        """Print a console notice for any weight other than 1.

        Weights below 0.5 or above 1.5 are shown in bright yellow, others in yellow.
        """
        if value == 1:
            return
        color_code = '\033[93m' if value < 0.5 or value > 1.5 else '\033[33m'
        print(f"\n{color_code}ATTENTION: Negative prompt weight is set to {value}\033[0m")
def xyz_support():
    """Add a ``[NPW] Weight`` float axis to every loaded ``xyz_grid.py`` script.

    The axis uses ``apply_field('NPW_weight')``, which presumably stores the
    plotted value on ``p.NPW_weight``, the attribute ``Script.process`` reads.
    """
    grid_modules = (
        data.module
        for data in scripts.scripts_data
        if os.path.basename(data.path) == 'xyz_grid.py'
    )
    for grid in grid_modules:
        axis = grid.AxisOption('[NPW] Weight', float, grid.apply_field('NPW_weight'))
        grid.axis_options.append(axis)
# Best-effort: the XYZ-plot axis is optional, so any failure is reported on the
# console but must not stop the rest of the extension from loading.
try:
    xyz_support()
except Exception as e:
    print('Error trying to add XYZ plot options for NPW', e)
|