File size: 9,555 Bytes
462d27c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f8ff1b
462d27c
 
 
 
 
 
 
 
0f8ff1b
 
462d27c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f8ff1b
462d27c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8becec0
8af685e
462d27c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
from model import models
from multit2i import (load_models, infer_fn, infer_rand_fn, save_gallery,
    change_model, warm_model, get_model_info_md, loaded_models, warm_models,
    get_positive_prefix, get_positive_suffix, get_negative_prefix, get_negative_suffix,
    get_recom_prompt_type, set_recom_prompt_preset, get_tag_type, randomize_seed, translate_to_en)
from tagger.tagger import (predict_tags_wd, remove_specific_prompt, convert_danbooru_to_e621_prompt,
    insert_recom_prompt, compose_prompt_to_copy)
from tagger.fl2sd3longcap import predict_tags_fl2_sd3
from tagger.v2 import V2_ALL_MODELS, v2_random_prompt
from tagger.utils import (V2_ASPECT_RATIO_OPTIONS, V2_RATING_OPTIONS,
    V2_LENGTH_OPTIONS, V2_IDENTITY_OPTIONS)

# Number of output image slots offered by the UI; also the count of models warmed below.
max_images = 6
# Upper bound of the seed slider: 2**32 - 1.
MAX_SEED = (1 << 32) - 1

# Hand the full model list to load_models, then warm only the first `max_images` entries.
load_models(models)
warm_models(models[:max_images])

# Custom stylesheet passed to gr.Blocks(css=...).
# Declarations must use `property: value`, and `!important` has to follow the value
# inside the same declaration; otherwise the browser discards the whole declaration.
css = """
.title { font-size: 3em; align-items: center; text-align: center; }
.info { align-items: center; text-align: center; }
.model_info { text-align: center; }
.output { width: 112px !important; height: 112px !important; max-width: 112px !important; max-height: 112px !important; }
.gallery { min-width: 512px !important; min-height: 512px !important; max-height: 1024px !important; }
"""

with gr.Blocks(theme="NoCrypt/miku@>=1.2.2", fill_width=True, css=css) as demo:
    # Loaded model names in dropdown order; the first one is the initial selection.
    model_choices = list(loaded_models.keys())

    with gr.Tab(""):
        with gr.Row():
            with gr.Column(scale=10):
                with gr.Group():
                    # NOTE(review): no event handler in this file currently reads these
                    # Prompt Transformer controls.
                    with gr.Accordion("Prompt Transformer", open=False):
                        with gr.Row(equal_height=True):
                            v2_rating = gr.Radio(label="Rating", choices=list(V2_RATING_OPTIONS), value="sfw")
                            v2_aspect_ratio = gr.Radio(label="Aspect ratio", info="The aspect ratio of the image.", choices=list(V2_ASPECT_RATIO_OPTIONS), value="square", visible=False)
                            v2_length = gr.Radio(label="Length", info="The total length of the tags.", choices=list(V2_LENGTH_OPTIONS), value="long")

                    prompt = gr.Text(label="Prompt", lines=2, max_lines=8, placeholder="1girl, solo, ...", show_copy_button=True)
                    with gr.Accordion("Advanced options", open=False):
                        neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="")
                        with gr.Row(equal_height=True):
                            width = gr.Slider(label="Width",  maximum=1216, step=32, value=0)
                            height = gr.Slider(label="Height",  maximum=1216, step=32, value=0)
                            steps = gr.Slider(label="Number of inference steps", maximum=100, step=1, value=0)
                        with gr.Row(equal_height=True):
                            cfg = gr.Slider(label="Guidance scale", maximum=30.0, step=0.1, value=0)
                            seed = gr.Slider(label="Seed",  minimum=-1, maximum=MAX_SEED, step=1, value=-1)
                            seed_rand = gr.Button("Randomize Seed 🎲", size="sm", variant="secondary")
                        recom_prompt_preset = gr.Radio(label="Set Presets", choices=get_recom_prompt_type(), value="Common")
                        with gr.Row(equal_height=True):
                            positive_prefix = gr.CheckboxGroup(label="Use Positive Prefix", choices=get_positive_prefix(), value=[])
                            positive_suffix = gr.CheckboxGroup(label="Use Positive Suffix", choices=get_positive_suffix(), value=["Common"])
                            negative_prefix = gr.CheckboxGroup(label="Use Negative Prefix", choices=get_negative_prefix(), value=[])
                            negative_suffix = gr.CheckboxGroup(label="Use Negative Suffix", choices=get_negative_suffix(), value=["Common"])
                    with gr.Row(equal_height=True):
                        image_num = gr.Slider(label="Number of images", minimum=1, maximum=max_images, value=1, step=1, interactive=True, scale=2)

                with gr.Row(equal_height=True):
                    run_button = gr.Button("Generate Image", variant="primary", scale=6)
                    random_button = gr.Button("Random Model 🎲", variant="secondary", scale=3)
                    #stop_button = gr.Button('Stop', variant="stop", interactive=False, scale=1)
                with gr.Group():
                    model_name = gr.Dropdown(label="Select Model", choices=model_choices, value=model_choices[0], allow_custom_value=True)
                    model_info = gr.Markdown(value=get_model_info_md(model_choices[0]), elem_classes="model_info")
            with gr.Column(scale=10):
                with gr.Group():
                    with gr.Row():
                        # One slot per possible image. Initially only the first `image_num`
                        # slots are shown, which is the same `i < n` rule applied when the
                        # slider changes.
                        output = [gr.Image(label='', elem_classes="output", type="filepath", format="png",
                                show_download_button=True, show_share_button=False, show_label=False,
                                interactive=False, min_width=80, visible=(i < image_num.value), width=112, height=112)
                                for i in range(max_images)]
                with gr.Group():
                    results = gr.Gallery(label="Gallery", elem_classes="gallery", interactive=False, show_download_button=True, show_share_button=False,
                                        container=True, format="png", object_fit="cover", columns=2, rows=2)
                    image_files = gr.Files(label="Download", interactive=False)
                    clear_results = gr.Button("Clear Gallery / Download 🗑️", variant="secondary")

    #gr.on(triggers=[run_button.click, prompt.submit, random_button.click], fn=lambda: gr.update(interactive=True), inputs=None, outputs=stop_button, show_api=False)
    # Changing the model refreshes the info panel, then calls warm_model on the new choice.
    model_name.change(change_model, [model_name], [model_info], queue=False, show_api=False)\
    .success(warm_model, [model_name], None, queue=False, show_api=False)

    # Generation inputs shared by both triggers, following the slot index and image count.
    gen_inputs = [model_name, prompt, neg_prompt, height, width, steps, cfg, seed,
                  positive_prefix, positive_suffix, negative_prefix, negative_suffix]
    for i, o in enumerate(output):
        # Hidden constant carrying this slot's index into the handlers.
        img_i = gr.Number(i, visible=False)
        # Show slot i only while i < requested number of images.
        image_num.change(lambda i, n: gr.update(visible = (i < n)), [img_i, image_num], o, show_api=False)
        # Every slot has its own handler; slots at or beyond the requested count yield None.
        gen_event = gr.on(triggers=[run_button.click, prompt.submit],
         fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4: infer_fn(m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4) if (i < n) else None,
         inputs=[img_i, image_num, *gen_inputs],
         outputs=[o], queue=False, show_api=False)  # Be sure to delete ", queue=False" when activating the stop button
        gen_event2 = gr.on(triggers=[random_button.click],
         fn=lambda i, n, m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4: infer_rand_fn(m, t1, t2, n1, n2, n3, n4, n5, l1, l2, l3, l4) if (i < n) else None,
         inputs=[img_i, image_num, *gen_inputs],
         outputs=[o], queue=False, show_api=False)  # Be sure to delete ", queue=False" when activating the stop button
        # Each new image in this slot goes through save_gallery, which updates the gallery and download list.
        o.change(save_gallery, [o, results], [results, image_files], show_api=False)
        #stop_button.click(lambda: gr.update(interactive=False), None, stop_button, cancels=[gen_event, gen_event2], show_api=False)

    clear_results.click(lambda: (None, None), None, [results, image_files], queue=False, show_api=False)
    recom_prompt_preset.change(set_recom_prompt_preset, [recom_prompt_preset],
     [positive_prefix, positive_suffix, negative_prefix, negative_suffix], queue=False, show_api=False)
    seed_rand.click(randomize_seed, None, [seed], queue=False, show_api=False)
    # NOTE(review): handlers for clear_prompt, trans_prompt, random_prompt and
    # tagger_generate_from_image are intentionally not wired. None of those components
    # are created in the layout above. The same is true of the inputs those handlers
    # read: v2_series, v2_character, v2_identity, v2_ban_tags, v2_model, v2_tag_type,
    # v2_copy and tagger_*. Referencing them raised NameError while building the UI.
    # Create the components first before wiring translate_to_en, v2_random_prompt or
    # the tagger pipeline again.

#demo.queue(default_concurrency_limit=200, max_size=200)

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script, so that
    # importing the module (for example to reuse `demo`) has no launch side effect.
    demo.launch(max_threads=400, ssr_mode=False)