Spaces:
Runtime error
Runtime error
Commit
Β·
266993c
1
Parent(s):
433ae3e
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from utils import
|
| 3 |
import random
|
| 4 |
|
| 5 |
is_clicked = False
|
|
@@ -89,43 +89,6 @@ with gr.Blocks() as app:
|
|
| 89 |
clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
|
| 90 |
|
| 91 |
|
| 92 |
-
def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
|
| 93 |
-
prompt = query + ' in the style of bulb'
|
| 94 |
-
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
|
| 95 |
-
return_tensors="pt")
|
| 96 |
-
input_ids = text_input.input_ids.to(torch_device)
|
| 97 |
-
|
| 98 |
-
# Get token embeddings
|
| 99 |
-
position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
|
| 100 |
-
position_embeddings = pos_emb_layer(position_ids)
|
| 101 |
-
|
| 102 |
-
s = seed_start
|
| 103 |
-
|
| 104 |
-
token_embeddings = token_emb_layer(input_ids)
|
| 105 |
-
# The new embedding - our special birb word
|
| 106 |
-
replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
|
| 107 |
-
|
| 108 |
-
# Insert this into the token embeddings
|
| 109 |
-
token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
|
| 110 |
-
|
| 111 |
-
# Combine with pos embs
|
| 112 |
-
input_embeddings = token_embeddings + position_embeddings
|
| 113 |
-
|
| 114 |
-
# Feed through to get final output embs
|
| 115 |
-
modified_output_embeddings = get_output_embeds(input_embeddings)
|
| 116 |
-
|
| 117 |
-
# And generate an image with this:
|
| 118 |
-
|
| 119 |
-
if contrast_loss and seed_values[concept_idx] > 0:
|
| 120 |
-
s = seed_values[concept_idx]
|
| 121 |
-
else:
|
| 122 |
-
s = random.randint(s + 1, s + 30)
|
| 123 |
-
seed_values[concept_idx] = s
|
| 124 |
-
|
| 125 |
-
g = torch.manual_seed(s)
|
| 126 |
-
return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)
|
| 127 |
-
|
| 128 |
-
|
| 129 |
def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
|
| 130 |
if not query:
|
| 131 |
raise gr.Error("No prompt provided")
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from utils import func_generate
|
| 3 |
import random
|
| 4 |
|
| 5 |
is_clicked = False
|
|
|
|
| 89 |
clear_btn2.click(clear_data2, None, [out11, out12, out13, out14, out15])
|
| 90 |
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def generate_image(query, con_idx, o1, o2, o3, o4, o5, contrast):
|
| 93 |
if not query:
|
| 94 |
raise gr.Error("No prompt provided")
|