Spaces:
Runtime error
Runtime error
Commit
Β·
4e75298
1
Parent(s):
266993c
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -218,3 +218,39 @@ token_emb_layer = text_encoder.text_model.embeddings.token_embedding
|
|
| 218 |
|
| 219 |
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
|
| 220 |
#pos_emb_layer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
|
| 219 |
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
|
| 220 |
#pos_emb_layer
|
| 221 |
+
|
| 222 |
+
def func_generate(query, concept_idx, seed_start, contrast_loss=False, contrast_perc=None):
|
| 223 |
+
prompt = query + ' in the style of bulb'
|
| 224 |
+
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True,
|
| 225 |
+
return_tensors="pt")
|
| 226 |
+
input_ids = text_input.input_ids.to(torch_device)
|
| 227 |
+
|
| 228 |
+
# Get token embeddings
|
| 229 |
+
position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
|
| 230 |
+
position_embeddings = pos_emb_layer(position_ids)
|
| 231 |
+
|
| 232 |
+
s = seed_start
|
| 233 |
+
|
| 234 |
+
token_embeddings = token_emb_layer(input_ids)
|
| 235 |
+
# The new embedding - our special birb word
|
| 236 |
+
replacement_token_embedding = concept_embeds[concept_idx].to(torch_device)
|
| 237 |
+
|
| 238 |
+
# Insert this into the token embeddings
|
| 239 |
+
token_embeddings[0, torch.where(input_ids[0] == 22373)] = replacement_token_embedding.to(torch_device)
|
| 240 |
+
|
| 241 |
+
# Combine with pos embs
|
| 242 |
+
input_embeddings = token_embeddings + position_embeddings
|
| 243 |
+
|
| 244 |
+
# Feed through to get final output embs
|
| 245 |
+
modified_output_embeddings = get_output_embeds(input_embeddings)
|
| 246 |
+
|
| 247 |
+
# And generate an image with this:
|
| 248 |
+
|
| 249 |
+
if contrast_loss and seed_values[concept_idx] > 0:
|
| 250 |
+
s = seed_values[concept_idx]
|
| 251 |
+
else:
|
| 252 |
+
s = random.randint(s + 1, s + 30)
|
| 253 |
+
seed_values[concept_idx] = s
|
| 254 |
+
|
| 255 |
+
g = torch.manual_seed(s)
|
| 256 |
+
return generate_with_embs(text_input, modified_output_embeddings, generator=g, contrast_loss=contrast_loss, contrast_perc=contrast_perc)
|