dikdimon's picture
Upload extensions using SD-Hub extension
f4a41d8 verified
import os
import gradio as gr
from webui import wrap_gradio_gpu_call
import modules.scripts
from modules import script_callbacks, shared, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding
def run_split(embedding_name, vector_index):
vector_index = int(vector_index)
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name]
vector = embedding.vec[vector_index]
new_name = f"{embedding_name}_{vector_index}"
filename = os.path.join(shared.cmd_opts.embeddings_dir, f"{new_name}.pt")
assert not os.path.exists(filename), f"file {filename} already exists"
print(f"Saving new embedding to {filename}\n")
split_embedding = Embedding(vector.unsqueeze(0), new_name)
split_embedding.step = 0
split_embedding.save(filename)
def add_tab():
with gr.Blocks(analytics_enabled=False) as ui:
embedding_name = gr.Dropdown(label='Embedding', elem_id="edit_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()), interactive=True)
vector_index = gr.Number(label='Vector', value=0, step=1, interactive=True)
go = gr.Button(value="Go", variant="primary")
go.click(
fn=run_split,
inputs=[embedding_name, vector_index],
outputs=[],
)
return [(ui, "Embedding Splitter", "embedding_splitter")]
script_callbacks.on_ui_tabs(add_tab)