File size: 1,505 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import gradio as gr
from webui import wrap_gradio_gpu_call
import modules.scripts
from modules import script_callbacks, shared, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding

def run_split(embedding_name, vector_index):
    """Save one vector of a loaded embedding as a new single-vector embedding.

    The new embedding is named ``{embedding_name}_{vector_index}`` and is
    written as ``<new name>.pt`` inside ``shared.cmd_opts.embeddings_dir``.

    Args:
        embedding_name: Key of the source embedding in
            ``sd_hijack.model_hijack.embedding_db.word_embeddings``.
        vector_index: Index of the vector to extract. It is converted with
            ``int()``, so a float coming from a numeric UI field is accepted.

    Raises:
        KeyError: If ``embedding_name`` is not a key of the loaded embeddings
            (for example when nothing was selected and it is ``None``).
        IndexError: If ``vector_index`` is out of range for the embedding.
        AssertionError: If the destination file already exists.
    """
    vector_index = int(vector_index)

    embeddings = sd_hijack.model_hijack.embedding_db.word_embeddings
    try:
        embedding = embeddings[embedding_name]
    except KeyError:
        raise KeyError(f"no loaded embedding named {embedding_name!r}") from None

    # Validate up front so the user gets a readable message instead of a bare
    # indexing error. Negative indices keep their usual "from the end" meaning.
    vector_count = len(embedding.vec)
    if not -vector_count <= vector_index < vector_count:
        raise IndexError(
            f"vector index {vector_index} is out of range for embedding "
            f"{embedding_name!r} with {vector_count} vector(s)"
        )
    vector = embedding.vec[vector_index]

    new_name = f"{embedding_name}_{vector_index}"

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f"{new_name}.pt")
    # Raised explicitly instead of via `assert` so the overwrite guard is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if os.path.exists(filename):
        raise AssertionError(f"file {filename} already exists")

    print(f"Saving new embedding to {filename}\n")

    # Indexing yields a view into the source tensor. Clone it so the new
    # embedding owns only its own data.
    # NOTE(review): assumes Embedding.save serializes with torch.save, which
    # would otherwise write the whole underlying storage of the source
    # embedding into the new file - confirm against Embedding.save.
    split_embedding = Embedding(vector.clone().unsqueeze(0), new_name)
    split_embedding.step = 0  # the split embedding is saved with step 0
    split_embedding.save(filename)
    

def add_tab():
    """Build the "Embedding Splitter" tab.

    Creates a Gradio ``Blocks`` layout containing an embedding dropdown, a
    vector-index number field and a "Go" button whose click calls
    ``run_split`` with the two input values.

    Returns:
        A one-element list holding the tuple
        ``(blocks, "Embedding Splitter", "embedding_splitter")``.
    """
    with gr.Blocks(analytics_enabled=False) as splitter_ui:
        # The dropdown choices are the embedding names known at the moment
        # this function runs, in sorted order.
        known_names = sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())
        name_dropdown = gr.Dropdown(
            label='Embedding',
            elem_id="edit_embedding",
            choices=known_names,
            interactive=True,
        )
        index_field = gr.Number(
            label='Vector',
            value=0,
            step=1,
            interactive=True,
        )

        split_button = gr.Button(value="Go", variant="primary")

        # run_split only writes a file, so the click has no output components.
        split_button.click(
            fn=run_split,
            inputs=[name_dropdown, index_field],
            outputs=[],
        )

    tab_entry = (splitter_ui, "Embedding Splitter", "embedding_splitter")
    return [tab_entry]


# Register add_tab as a UI-tabs callback at import time. When the callback is
# actually invoked is decided by script_callbacks, which is not shown here.
script_callbacks.on_ui_tabs(add_tab)