File size: 5,933 Bytes
3dabe4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import torch
import gradio as gr

import modules.scripts as scripts
import modules.shared as shared
from modules.script_callbacks import on_cfg_denoiser, remove_current_script_callbacks
from modules.prompt_parser import SdConditioning


class Script(scripts.Script):
    """Always-visible script that scales the influence of the negative prompt.

    While a weight other than 1 is active, a cfg-denoiser callback replaces the
    unconditional conditioning with ``lerp(empty, uncond, weight)``, where
    ``empty`` is the conditioning produced for an empty negative prompt.
    A weight of 0 therefore yields the empty conditioning and a weight of 1
    leaves the negative prompt untouched.
    """

    def title(self):
        """Return the title of this script."""
        # NOTE(review): "Extention" is misspelled, but the string is kept
        # byte-identical in case the host keys anything on the title.
        return "Negative Prompt Weight Extention"

    def show(self, is_img2img):
        """Make the script always visible, regardless of the tab."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the accordion with a slider, a number box and a reset button.

        The slider and the number box mirror each other through client-side
        JavaScript; the reset button sets both back to 1.

        Returns:
            A single-element list holding the number component, whose value is
            passed to ``process`` as ``weight``.
        """
        with gr.Accordion("Negative Prompt Weight", open=True, elem_id="npw"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=100):
                    weight_input_slider = gr.Slider(minimum=0.00, maximum=2.00, step=.05, value=1.00, label="Weight", interactive=True, elem_id="npw-slider")
                with gr.Column(scale=1, min_width=120):
                    with gr.Row():
                        weight_input = gr.Number(value=1.00, precision=4, label="Negative Prompt Weight", show_label=False, elem_id="npw-number")
                        reset_but = gr.Button(value='✕', elem_id='npw-x', size='sm')

            # Client-side hook: outlines the reset button of the visible tab
            # with an opacity of sqrt(|v - 1|), then forwards the value.
            js = """(v) => {
              ['#tab_txt2img #npw-x', '#tab_img2img #npw-x'].forEach((selector, index) => {
                const element = document.querySelector(selector);
                if (document.querySelector(`#tab_${index ? 'img2img' : 'txt2img'}`).style.display === 'block') {
                  element.style.cssText += `outline:4px solid rgba(255,186,0,${Math.sqrt(Math.abs(v-1))}); border-radius: 0.4em !important;`;
                }
              });
              return v;
            }"""

            weight_input.change(None, [weight_input], weight_input_slider, _js=js)
            weight_input_slider.change(None, weight_input_slider, weight_input, _js="(x) => x")
            reset_but.click(None, [], [weight_input, weight_input_slider], _js="(x) => [1,1]")

        # Let the "NPW_weight" infotext entry be pasted back into the number box.
        self.infotext_fields = [(weight_input, "NPW_weight")]
        self.paste_field_names = [field_name for _, field_name in self.infotext_fields]

        return [weight_input]

    def process(self, p, weight):
        """Prepare one generation run: store the weight and hook the denoiser.

        Args:
            p: Processing object; ``width``, ``height`` and
                ``extra_generation_params`` are read/updated. An ``NPW_weight``
                attribute on ``p`` (set by the XYZ-plot axis registered in this
                file) overrides ``weight``.
            weight: Value of the number component returned by ``ui``.
        """
        weight = getattr(p, 'NPW_weight', weight)
        if weight is None:
            # A cleared gr.Number field yields None; treat it as the neutral
            # weight instead of failing on the numeric comparisons below.
            weight = 1.0
        if weight != 1:
            self.print_warning(weight)

        self.width = p.width
        self.height = p.height
        self.weight = weight
        self.empty_uncond = None

        # Drop a callback left over from a previous run (for example one that
        # never reached postprocess) before possibly registering a new one.
        self._remove_callbacks()

        if self.weight != 1.0:
            self.empty_uncond = self.make_empty_uncond(self.width, self.height)
            on_cfg_denoiser(self.denoiser_callback)
            self.callbacks_added = True

            # Record the weight in the generation parameters (infotext).
            p.extra_generation_params.update({
                "NPW_weight": self.weight,
            })

    def postprocess(self, p, processed, *args):
        """Unregister the denoiser callback once the run has finished."""
        self._remove_callbacks()

    def _remove_callbacks(self):
        """Remove this script's callbacks if ``process`` registered any."""
        if hasattr(self, 'callbacks_added'):
            remove_current_script_callbacks()
            delattr(self, 'callbacks_added')

    @staticmethod
    def _concat_and_lerp(empty, tensor, weight):
        """Interpolate from ``empty`` towards ``tensor`` by ``weight``.

        ``empty`` is first expanded along dim 0 to the batch size of
        ``tensor``. If ``tensor`` is longer along dim 1, ``empty`` is repeated
        ``tensor.shape[1] // empty.shape[1]`` times along that dim; when
        ``tensor`` is then exactly one entry longer, its first entry is
        prepended so the shapes match.
        """
        if empty.shape[0] != tensor.shape[0]:
            empty = empty.expand(tensor.shape[0], *empty.shape[1:])
        if tensor.shape[1] > empty.shape[1]:
            repeats = tensor.shape[1] // empty.shape[1]
            empty = torch.cat([empty] * repeats, dim=1)
            if tensor.shape[1] == empty.shape[1] + 1:
                # assuming it's controlnet's marks(?)
                empty = torch.cat([tensor[:, :1, :], empty], dim=1)
        return torch.lerp(empty, tensor, weight)

    def denoiser_callback(self, params):
        """Blend ``params.text_uncond`` with the empty-prompt conditioning.

        Handles both a plain tensor and a dict with ``'vector'`` and
        ``'crossattn'`` entries; the dict is updated in place.
        """
        uncond = params.text_uncond
        is_dict = isinstance(uncond, dict)
        # Rebuild the cached empty conditioning only when it is missing or its
        # kind (dict vs. non-dict) differs from what the sampler hands us.
        # An exact type() comparison would never match if the host passes a
        # dict subclass, forcing a re-encode on every step.
        # NOTE(review): presumably the host does wrap dict conditionings in a
        # subclass — confirm against the host's prompt_parser.
        if self.empty_uncond is None or isinstance(self.empty_uncond, dict) != is_dict:
            self.empty_uncond = self.make_empty_uncond(self.width, self.height)
        empty_uncond = self.empty_uncond

        if is_dict:
            vector, cross = uncond['vector'], uncond['crossattn']
            empty_vector, empty_cross = empty_uncond['vector'], empty_uncond['crossattn']
            params.text_uncond['vector'] = self._concat_and_lerp(empty_vector, vector, self.weight)
            params.text_uncond['crossattn'] = self._concat_and_lerp(empty_cross, cross, self.weight)
        else:
            params.text_uncond = self._concat_and_lerp(empty_uncond, uncond, self.weight)

    def make_empty_uncond(self, w, h):
        """Return the model's conditioning for an empty negative prompt.

        Args:
            w: Image width passed to ``SdConditioning``.
            h: Image height passed to ``SdConditioning``.
        """
        prompt = SdConditioning([""], is_negative_prompt=True, width=w, height=h)
        return shared.sd_model.get_learned_conditioning(prompt)

    def print_warning(self, value):
        """Print a coloured console notice when ``value`` differs from 1.

        Values below 0.5 or above 1.5 use the bright-yellow escape code, all
        others the regular yellow one.
        """
        if value == 1:
            return
        color_code = '\033[33m'
        if value < 0.5 or value > 1.5:
            color_code = '\033[93m'
        print(f"\n{color_code}ATTENTION: Negative prompt weight is set to {value}\033[0m")


def xyz_support():
    """Add a '[NPW] Weight' axis to every loaded ``xyz_grid.py`` script module.

    Each entry of ``scripts.scripts_data`` whose path has the basename
    ``xyz_grid.py`` gets a float ``AxisOption`` appended to its module's
    ``axis_options``; the option applies its value to the ``NPW_weight`` field.
    """
    # Lazy filter: an entry's module is only touched after its path matched.
    grid_modules = (
        entry.module
        for entry in scripts.scripts_data
        if os.path.basename(entry.path) == 'xyz_grid.py'
    )
    for grid in grid_modules:
        weight_axis = grid.AxisOption('[NPW] Weight', float, grid.apply_field('NPW_weight'))
        grid.axis_options.extend([weight_axis])
# Register the XYZ-plot axis at import time. Any failure is only reported on
# the console so that the module still finishes importing.
try:
    xyz_support()
except Exception as err:
    print('Error trying to add XYZ plot options for NPW', err)