Commit · e55aab4
1 Parent(s): 430249a

update sol probs others

Files changed:
- .gitignore +1 -0
- app.py +224 -211
- twitter_prompts.csv +47 -0
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__*
app.py CHANGED
@@ -6,21 +6,27 @@ import torch

# lol
DEVICE = 'cuda'
-STEPS =
output_hidden_state = False
device = "cuda"
dtype = torch.bfloat16
-N_IMG_EMBS = 3

import logging
import os
import imageio
import gradio as gr
import numpy as np
-from sklearn.svm import
-from sklearn import preprocessing
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler

import random
import time
@@ -37,8 +43,12 @@ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'l
import spaces
start_time = time.time()

####################### Setup Model
-from diffusers import
from transformers import CLIPTextModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
@@ -46,9 +56,8 @@ from PIL import Image
from transformers import CLIPVisionModelWithProjection
import uuid
import av
-import torchvision

-def write_video(file_name, images, fps=
    container = av.open(file_name, mode="w")

    stream = container.add_stream("h264", rate=fps)
@@ -89,182 +98,133 @@ device_map='cuda')
#unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
#text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)


-
-
-

-
-pipe =
-    unet=unet, text_encoder=text_encoder)
-pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
-pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
-pipe.set_adapters(["lcm-lora"], [.95])
-pipe.fuse_lora()


-#pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
-#pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
-#repo = "ByteDance/AnimateDiff-Lightning"
-#ckpt = f"animatediff_lightning_4step_diffusers.safetensors"


-pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
-# This IP adapter improves outputs substantially.
-pipe.set_ip_adapter_scale(.6)
-pipe.unet.fuse_qkv_projections()
-#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)

-pipe.

#pipe.unet = torch.compile(pipe.unet)
#pipe.vae = torch.compile(pipe.vae)


-#############################################################
-
-from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
-
-pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-mix-224', torch_dtype=dtype).eval().to('cuda')
-processor = AutoProcessor.from_pretrained('google/paligemma-3b-mix-224')
-
-#pali = torch.compile(pali)

@spaces.GPU()
-def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
-    inputs_embeds = pali.get_input_embeddings()(input_ids.to('cuda'))
-    selected_image_feature = image_outputs.to(dtype).to('cuda')
-    image_features = pali.multi_modal_projector(selected_image_feature)
-
-    if cache_position is None:
-        cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
-    inputs_embeds, attention_mask, labels, position_ids = pali._merge_input_ids_with_image_features(
-        image_features, inputs_embeds, input_ids, attention_mask, None, None, cache_position
-    )
-    return inputs_embeds
-
-
-# TODO cache descriptions?
-@spaces.GPU(duration=20)
-def generate_pali(n_embs):
-    prompt = 'caption en'
-    model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
-    # we need to get im_embs taken in here.
-
-    descs = ''
-    for n, emb in enumerate(n_embs):
-        if n < len(n_embs)-1:
-            input_len = model_inputs["input_ids"].shape[-1]
-            input_embeds = to_wanted_embs(emb,
-                model_inputs["input_ids"].to(device),
-                model_inputs["attention_mask"].to(device))
-            generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
-            decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            descs += f'Description: {decoded}\n'
-        else:
-            prompt = f'en {descs} Describe a new image that is similar. Description:'
-            model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
-            input_len = model_inputs["input_ids"].shape[-1]
-            input_embeds = to_wanted_embs(emb,
-                model_inputs["input_ids"].to(device),
-                model_inputs["attention_mask"].to(device))
-            generation = pali.generate(max_new_tokens=20, do_sample=True, top_p=.94, temperature=1.2, inputs_embeds=input_embeds)
-            decoded = processor.decode(generation[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-
-    return decoded
-
-
-
-
-#############################################################
-
-
-
-@spaces.GPU(duration=20)
def generate_gpu(in_im_embs, prompt='the scene'):
    with torch.no_grad():
-
        output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
        im_emb, _ = pipe.encode_image(
-            output.
        )
        im_emb = im_emb.detach().to('cpu').to(torch.float32)
-
-        im = torch.nn.functional.interpolate(im, (224, 224)).to(dtype).to('cuda')
-        im = (im - .5) * 2
-        gemb = pali.vision_tower(im).last_hidden_state.detach().to('cpu').to(torch.float32)
-        return output, im_emb, gemb


def generate(in_im_embs, prompt='the scene'):
-    output, im_emb
-    nsfw =maybe_nsfw(output.
    name = str(uuid.uuid4()).replace("-", "")
-    path = f"/tmp/{name}.

    if nsfw:
        gr.Warning("NSFW content detected.")
        # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
-        return None, im_emb
-

-    output.
-
-    write_video(path, output.frames[0])
-    return path, im_emb, gemb


#######################

def get_user_emb(embs, ys):
-    # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
-
-    if len(list(ys)) <= 10:
-        aways = [torch.zeros_like(embs[0]) for i in range(10)]
-        embs += aways
-        awal = [0 for i in range(5)] + [1 for i in range(5)]
-        ys += awal
-
-    indices = list(range(len(embs)))
    # sample only as many negatives as there are positives
-
-
-
-    #neg_indices = random.sample(neg_indices, lower)
-    #pos_indices = random.sample(pos_indices, lower)
-

-
-
-    #
-
-
-

-
-
-
-

-
-

-    #
-
-
-
-    coef_ = coef_ / coef_.abs().max()

-
-
-    return im_emb


def pluck_img(user_id, user_emb):
    not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
    while len(not_rated_rows) == 0:
        not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
-        time.sleep(.
    # TODO optimize this lol
    best_sim = -100000
    for i in not_rated_rows.iterrows():
@@ -274,8 +234,7 @@ def pluck_img(user_id, user_emb):
            best_sim = sim
            best_row = i[1]
    img = best_row['paths']
-
-    return img, text


def background_next_image():
@@ -283,10 +242,10 @@ def background_next_image():
    # only let it get N (maybe 3) ahead of the user
    #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
    rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
-
    # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
-
-        time.sleep(.01)

    user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
    for uid in user_id_list:
@@ -300,22 +259,32 @@ def background_next_image():
        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]

        # we pop previous ratings if there are > n
-        if len(rated_from_user) >=
            oldest = rated_from_user.iloc[0]['paths']
            prevs_df = prevs_df[prevs_df['paths'] != oldest]
        # we don't compute more after n are in the queue for them
-        if len(unrated_from_user) >=
            continue

-
-
-
-
-
-
        else:
-            text = '
-            img, embs

        if img:
            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
@@ -324,7 +293,6 @@ def background_next_image():
            tmp_df['user:rating'] = [{' ': ' '}]
            tmp_df['from_user_id'] = [uid]
            tmp_df['text'] = [text]
-            tmp_df['gemb'] = [new_gem]
            prevs_df = pd.concat((prevs_df, tmp_df))
            # we can free up storage by deleting the image
            if len(prevs_df) > 500:
@@ -340,37 +308,52 @@ def background_next_image():

def pluck_embs_ys(user_id):
    rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]

    embs = rated_rows['embeddings'].to_list()
    ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
-
-    return embs, ys, gembs

def next_image(calibrate_prompts, user_id):
    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            cal_video = calibrate_prompts.pop(0)
            image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
-            return image, calibrate_prompts,
        else:
-            embs, ys
-
-
-



def start(_, calibrate_prompts, user_id, request: gr.Request):
    user_id = int(str(time.time())[-7:].replace('.', ''))
-    image, calibrate_prompts
    return [
-        gr.Button(value='
        gr.Button(value='Neither (Space)', interactive=True, visible=False),
-        gr.Button(value='
        gr.Button(value='Start', interactive=False),
        image,
        calibrate_prompts,
-        user_id
    ]

@@ -378,27 +361,34 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    global prevs_df


-    if choice == '
-        choice = 1
    elif choice == 'Neither (Space)':
-        img, calibrate_prompts,
-        return img, calibrate_prompts,
    else:
-        choice

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating & just continue
-
        print('NSFW -- choice is disliked')
-        choice = 0

    row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
    # if it's still in the dataframe, add the choice
    if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
        prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
        prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
-    img, calibrate_prompts,
-    return img, calibrate_prompts

css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
@@ -461,53 +451,71 @@ Explore the latent space without text prompts based on your preferences. Learn m
    user_id = gr.State()
    # calibration videos -- this is a misnomer now :D
    calibrate_prompts = gr.State([
-        './first.
-        './second.
-        './
-        './
-        './
-        './sixth.mp4',
    ])
    def l():
        return None

    with gr.Row(elem_id='output-image'):
-        img = gr.
            label='Lightning',
-            autoplay=True,
            interactive=False,
-            height=512,
-            width=512,
            #include_audio=False,
-            elem_id="video_output"
        )
-    img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
-
-
    with gr.Row(equal_height=True):
-        b3 = gr.Button(value='
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
-
    b1.click(
    choose,
    [img, b1, calibrate_prompts, user_id],
-    [img, calibrate_prompts,
    )
    b2.click(
    choose,
    [img, b2, calibrate_prompts, user_id],
-    [img, calibrate_prompts,
    )
    b3.click(
    choose,
    [img, b3, calibrate_prompts, user_id],
-    [img, calibrate_prompts,
    )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
        [b4, calibrate_prompts, user_id],
-        [b1, b2, b3, b4, img, calibrate_prompts, user_id]
        )
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </ div><br><br><br>
@@ -518,37 +526,42 @@ Explore the latent space without text prompts based on your preferences. Learn m
</ div>''')

# TODO quiet logging
-log = logging.getLogger('log_here')
-log.setLevel(logging.ERROR)

scheduler = BackgroundScheduler()
-scheduler.add_job(func=background_next_image, trigger="interval", seconds=.
scheduler.start()


# prep our calibration videos
-for im in [
-    './first.
-    './second.
-    './
-    './
-    './
-    './sixth.mp4',
-    './seventh.mp4',
-    './eigth.mp4',
-    './ninth.mp4',
-    './tenth.mp4',
]:
    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
    tmp_df['paths'] = [im]
    image = list(imageio.imiter(im))
    image = image[len(image)//2]
-
-
    tmp_df['user:rating'] = [{' ': ' '}]
    prevs_df = pd.concat((prevs_df, tmp_df))

-
-demo.launch(share=True)

@@ -6,21 +6,27 @@ import torch

# lol
DEVICE = 'cuda'
+STEPS = 8
output_hidden_state = False
device = "cuda"
dtype = torch.bfloat16

+
+import spaces
+
+import matplotlib.pyplot as plt
+import matplotlib
import logging
+
import os
import imageio
import gradio as gr
import numpy as np
+from sklearn.svm import LinearSVC
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler
+import sched
+import threading

import random
import time
@@ -37,8 +43,12 @@ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'l
import spaces
start_time = time.time()

+prompt_list = [p for p in list(set(
+    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
+
+
####################### Setup Model
+from diffusers import EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL, AutoPipelineForText2Image
from transformers import CLIPTextModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
@@ -46,9 +56,8 @@ from PIL import Image
from transformers import CLIPVisionModelWithProjection
import uuid
import av

+def write_video(file_name, images, fps=16):
    container = av.open(file_name, mode="w")

    stream = container.add_stream("h264", rate=fps)
@@ -89,182 +98,133 @@ device_map='cuda')
#unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
#text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)

+#rynmurdock/Sea_Claws
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+sdxl_lightening = "ByteDance/SDXL-Lightning"
+ckpt = "sdxl_lightning_8step_unet.safetensors"
+unet = UNet2DConditionModel.from_config(model_id, subfolder="unet", low_cpu_mem_usage=True, device_map=DEVICE).to(torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt)))

+image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=DEVICE)
+pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder, low_cpu_mem_usage=True)
+pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
+pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
+pipe.register_modules(image_encoder = image_encoder)
+pipe.set_ip_adapter_scale(0.8)

+#pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16, low_cpu_mem_usage=True)
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

+pipe.to(device=DEVICE).to(dtype=dtype)
+output_hidden_state = False




+# pipe.unet.fuse_qkv_projections()
+#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)

#pipe.unet = torch.compile(pipe.unet)
#pipe.vae = torch.compile(pipe.vae)



@spaces.GPU()
def generate_gpu(in_im_embs, prompt='the scene'):
    with torch.no_grad():
+        print(prompt)
+        in_im_embs = in_im_embs.to('cuda').unsqueeze(0)
        output = pipe(prompt=prompt, guidance_scale=1, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
        im_emb, _ = pipe.encode_image(
+            output.images[0], 'cuda', 1, output_hidden_state
        )
        im_emb = im_emb.detach().to('cpu').to(torch.float32)
+        return output, im_emb


def generate(in_im_embs, prompt='the scene'):
+    output, im_emb = generate_gpu(in_im_embs, prompt)
+    nsfw = maybe_nsfw(output.images[0])
+
    name = str(uuid.uuid4()).replace("-", "")
+    path = f"/tmp/{name}.png"

    if nsfw:
        gr.Warning("NSFW content detected.")
        # TODO could return an automatic dislike of auto dislike on the backend for neither as well; just would need refactoring.
+        return None, im_emb

+    output.images[0].save(path)
+    return path, im_emb


#######################

+
+
+
+
+
+@spaces.GPU()
+def solver(embs, ys):
+    print('ys:', ys,'EMBS:', embs.shape, embs)
+    ys = torch.tensor(ys).to('cpu', dtype=torch.float32).squeeze().unsqueeze(1)
+
+    sol = LinearSVC(class_weight='balanced').fit(np.array(embs), np.array(torch.tensor(ys).float() * 2 - 1)).coef_
+    return torch.tensor(sol).to('cpu', dtype=torch.float32)
+
+
+
+
def get_user_emb(embs, ys):
    # sample only as many negatives as there are positives
+    indices = range(len(ys))
+    pos_indices = [i for i in indices if ys[i] > .5]
+    neg_indices = [i for i in indices if ys[i] <= .5]

+    mini = min(len(pos_indices), len(neg_indices))
+
+    if len(ys) > 20: # drop earliest of whichever of neg or pos is most abundant
+        if len(pos_indices) > len(neg_indices):
+            ind = pos_indices[0]
+        else:
+            ind = neg_indices[0]
+        ys.pop(ind)
+        embs.pop(ind)
+        print('Dropping at 20')

+    if mini < 1:
+        feature_embs = torch.stack([torch.randn(1280), torch.randn(1280)])
+        ys_t = [0, 1]
+        print('Not enough ratings.')
+    else:
+        indices = range(len(ys))
+        ys_t = [ys[i] for i in indices]
+        feature_embs = torch.stack([embs[e].detach().cpu() for e in indices]).squeeze()
+
+        # scaler = preprocessing.StandardScaler().fit(feature_embs)
+        # feature_embs = scaler.transform(feature_embs)
+        # ys_t = ys
+
+    print(np.array(feature_embs).shape, np.array(ys_t).shape)

+    sol = solver(feature_embs.squeeze(), ys_t)
+    dif = torch.tensor(sol, dtype=dtype).to(device)

+    # could j have a base vector of a black image
+    latest_pos = (random.sample([feature_embs[i] for i in range(len(ys_t)) if ys_t[i] > .5], 1)[0]).to(device, dtype)
+
+    dif = ((dif / dif.std()) * latest_pos.std())

+    sol = (1*latest_pos + 3*dif)/4
+    return sol


def pluck_img(user_id, user_emb):
    not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
    while len(not_rated_rows) == 0:
        not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, 'gone') == 'gone' for i in prevs_df.iterrows()]]
+        time.sleep(.1)
    # TODO optimize this lol
    best_sim = -100000
    for i in not_rated_rows.iterrows():
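The new solver above fits a linear SVM over the stored image embeddings, with likes and dislikes as the two classes, and get_user_emb blends the resulting hyperplane normal with a recently liked embedding to form the next conditioning vector. A minimal, self-contained sketch of that idea outside the Space (toy 1280-dimensional embeddings and made-up ratings, matching the torch.randn(1280) fallback in get_user_emb; not the app's exact data flow):

import numpy as np
import torch
from sklearn.svm import LinearSVC

# toy stand-ins: eight 1280-dim embeddings with hypothetical 0/1 ratings
embs = torch.randn(8, 1280)
ys = [0, 1, 1, 0, 1, 0, 1, 1]

# fit a balanced linear SVM on {-1, +1} labels; coef_ is the like-vs-dislike direction
coef = LinearSVC(class_weight='balanced').fit(embs.numpy(), np.array(ys) * 2 - 1).coef_
direction = torch.tensor(coef, dtype=torch.float32).squeeze()

# rescale the direction to a liked embedding's scale and blend, as get_user_emb does
latest_pos = embs[[i for i, y in enumerate(ys) if y > 0.5][-1]]
direction = direction / direction.std() * latest_pos.std()
user_emb = (1 * latest_pos + 3 * direction) / 4
print(user_emb.shape)  # torch.Size([1280])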
@@ -274,8 +234,7 @@ def pluck_img(user_id, user_emb):
            best_sim = sim
            best_row = i[1]
    img = best_row['paths']
+    return img


def background_next_image():
@@ -283,10 +242,10 @@ def background_next_image():
    # only let it get N (maybe 3) ahead of the user
    #not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
    rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
+    if len(rated_rows) < 4:
+        time.sleep(.1)
    # not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
+        return

    user_id_list = set(rated_rows['latest_user_to_rate'].to_list())
    for uid in user_id_list:
@@ -300,22 +259,32 @@ def background_next_image():
        rated_from_user = rated_rows[[i[1]['from_user_id'] == uid for i in rated_rows.iterrows()]]

        # we pop previous ratings if there are > n
+        if len(rated_from_user) >= 15:
            oldest = rated_from_user.iloc[0]['paths']
            prevs_df = prevs_df[prevs_df['paths'] != oldest]
        # we don't compute more after n are in the queue for them
+        if len(unrated_from_user) >= 10:
            continue

+        if len(rated_rows) < 5:
+            continue
+
+        embs, ys = pluck_embs_ys(uid)
+
+        user_emb = get_user_emb(embs, [y[1] for y in ys])
+
+
+        global glob_idx
+        glob_idx += 1
+        if glob_idx >= (len(prompt_list)-1):
+            glob_idx = 0
+
+
+        if glob_idx % 7 == 0:
+            text = prompt_list[glob_idx]
        else:
+            text = 'an image'
+        img, embs = generate(user_emb, text)

        if img:
            tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'text', 'gemb'])
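background_next_image (above) is what keeps each user's queue of unrated images topped up; it is driven by the APScheduler interval job registered near the bottom of the file (scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)). A minimal sketch of that scheduling pattern, with a hypothetical work() standing in for the real job:

import time
from apscheduler.schedulers.background import BackgroundScheduler

def work():
    # stand-in for background_next_image(): do a small amount of work and return quickly,
    # since the scheduler will fire it again a fraction of a second later
    print('topping up the image queue')

scheduler = BackgroundScheduler()
scheduler.add_job(func=work, trigger="interval", seconds=0.2)
scheduler.start()

time.sleep(1)        # the job fires roughly five times while the main thread waits
scheduler.shutdown()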
@@ -324,7 +293,6 @@ def background_next_image():
            tmp_df['user:rating'] = [{' ': ' '}]
            tmp_df['from_user_id'] = [uid]
            tmp_df['text'] = [text]
            prevs_df = pd.concat((prevs_df, tmp_df))
            # we can free up storage by deleting the image
            if len(prevs_df) > 500:
@@ -340,37 +308,52 @@ def background_next_image():

def pluck_embs_ys(user_id):
    rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
+    #not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
+    #while len(not_rated_rows) == 0:
+    #    not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
+    #    rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
+    #    time.sleep(.01)
+    #    print('current user has 0 not_rated_rows')

    embs = rated_rows['embeddings'].to_list()
    ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
+    return embs, ys

def next_image(calibrate_prompts, user_id):
    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            cal_video = calibrate_prompts.pop(0)
            image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
+            return image, calibrate_prompts,
        else:
+            embs, ys = pluck_embs_ys(user_id)
+            ys_here = [y[1] for y in ys]
+            user_emb = get_user_emb(embs, ys_here)
+            image = pluck_img(user_id, user_emb)
+            return image, calibrate_prompts,
+
+
+
+
+
+



def start(_, calibrate_prompts, user_id, request: gr.Request):
    user_id = int(str(time.time())[-7:].replace('.', ''))
+    image, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return [
+        gr.Button(value='👍', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True, visible=False),
+        gr.Button(value='👎', interactive=True),
        gr.Button(value='Start', interactive=False),
+        gr.Button(value='👍 Content', interactive=True, visible=False),
+        gr.Button(value='👍 Style', interactive=True, visible=False),
        image,
        calibrate_prompts,
+        user_id,
+
    ]

@@ -378,27 +361,34 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    global prevs_df


+    if choice == '👍':
+        choice = [1, 1]
    elif choice == 'Neither (Space)':
+        img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
+        return img, calibrate_prompts,
+    elif choice == '👎':
+        choice = [0, 0]
+    elif choice == '👍 Style':
+        choice = [0, 1]
+    elif choice == '👍 Content':
+        choice = [1, 0]
    else:
+        assert False, f'choice is {choice}'

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating & just continue
+
+    if img is None:
        print('NSFW -- choice is disliked')
+        choice = [0, 0]

    row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
    # if it's still in the dataframe, add the choice
    if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
        prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
        prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
+    img, calibrate_prompts, = next_image(calibrate_prompts, user_id)
+    return img, calibrate_prompts

css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
@@ -461,53 +451,71 @@ Explore the latent space without text prompts based on your preferences. Learn m
    user_id = gr.State()
    # calibration videos -- this is a misnomer now :D
    calibrate_prompts = gr.State([
+        './first.png',
+        './second.png',
+        './sixth.png',
+        './fifth.png',
+        './fourth.png',
    ])
    def l():
        return None

    with gr.Row(elem_id='output-image'):
+        img = gr.Image(
            label='Lightning',
+            # autoplay=True,
            interactive=False,
+            # height=512,
+            # width=512,
            #include_audio=False,
+            elem_id="video_output",
+            type='filepath',
        )
+    #img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
+
+
+
    with gr.Row(equal_height=True):
+        b3 = gr.Button(value='👎', interactive=False, elem_id="dislike")
+
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither", visible=False)
+
+        b1 = gr.Button(value='👍', interactive=False, elem_id="like")
+    with gr.Row(equal_height=True):
+        b6 = gr.Button(value='👍 Style', interactive=False, elem_id="dislike like", visible=False)
+
+        b5 = gr.Button(value='👍 Content', interactive=False, elem_id="like dislike", visible=False)
+
    b1.click(
    choose,
    [img, b1, calibrate_prompts, user_id],
+    [img, calibrate_prompts, ],
    )
    b2.click(
    choose,
    [img, b2, calibrate_prompts, user_id],
+    [img, calibrate_prompts, ],
    )
    b3.click(
    choose,
    [img, b3, calibrate_prompts, user_id],
+    [img, calibrate_prompts, ],
+    )
+    b5.click(
+    choose,
+    [img, b5, calibrate_prompts, user_id],
+    [img, calibrate_prompts, ],
+    )
+    b6.click(
+    choose,
+    [img, b6, calibrate_prompts, user_id],
+    [img, calibrate_prompts, ],
    )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
        [b4, calibrate_prompts, user_id],
+        [b1, b2, b3, b4, b5, b6, img, calibrate_prompts, user_id, ]
        )
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam. </ div><br><br><br>
@@ -518,37 +526,42 @@ Explore the latent space without text prompts based on your preferences. Learn m
</ div>''')

# TODO quiet logging

scheduler = BackgroundScheduler()
+scheduler.add_job(func=background_next_image, trigger="interval", seconds=.2)
scheduler.start()

+#thread = threading.Thread(target=background_next_image,)
+#thread.start()
+
+# TODO shouldn't call this before gradio launch, yeah?
+@spaces.GPU()
+def encode_space(x):
+    im_emb, _ = pipe.encode_image(
+        image, DEVICE, 1, output_hidden_state
+    )
+    return im_emb.detach().to('cpu').to(torch.float32)

# prep our calibration videos
+for im, txt in [ # TODO more movement
+    ('./first.png', 'describe the scene: a sketch'),
+    ('./second.png', 'describe the scene: omens in the suburbs'),
+    ('./sixth.png', 'describe the scene: geometric abstract art of a windmill'),
+    ('./fifth.png', 'describe the scene: memento mori'),
+    ('./fourth.png', 'describe the scene: a green plate with anespresso'),
]:
    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'text', 'gemb'])
    tmp_df['paths'] = [im]
    image = list(imageio.imiter(im))
    image = image[len(image)//2]
+    im_emb = encode_space(image)
+
+    tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
    tmp_df['user:rating'] = [{' ': ' '}]
+    tmp_df['text'] = [txt]
    prevs_df = pd.concat((prevs_df, tmp_df))

+glob_idx = 0
+demo.launch(share=True,)
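Taken together, the new app.py generates an image from the current user embedding, stores the IP-Adapter embedding of what it produced, and refits the preference direction after each rating. A condensed sketch of that loop, assuming the generate() and get_user_emb() defined above are in scope; ratings are simplified to single 0/1 values rather than the [content, style] pairs the UI actually records:

import torch

embs, ys = [], []
user_emb = torch.randn(1280)                        # cold start: arbitrary conditioning embedding

for step in range(4):
    path, im_emb = generate(user_emb, 'an image')   # SDXL-Lightning guided by the IP-Adapter embedding
    rating = 1 if step % 2 == 0 else 0              # stand-in for a thumbs-up / thumbs-down click
    embs.append(im_emb)
    ys.append(rating)
    user_emb = get_user_emb(embs, ys)               # refit the LinearSVC and blend with a liked embedding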
twitter_prompts.csv ADDED
@@ -0,0 +1,47 @@
+,0
+0,a sunset
+1,a still life in blue
+2,last day on earth
+3,the conch shell
+4,the winds of change
+5,a surrealist eye
+6,a surrealist polaroid photo of an apple
+7,metaphysics
+8,the sun is setting into my glass of tea
+9,the moon at 3am
+10,a memento mori
+11,quaking aspen tree
+12,violets and daffodils
+13,espresso
+14,sisyphus
+15,high windows of stained glass
+16,a green dog
+17,an adorable companion; it is a pig
+18,bird of paradise
+19,a complex intricate machine
+20,a white clock
+21,a film featuring the landscape Salt Lake City Utah
+22,a creature
+23,a house set aflame.
+24,a gorgeous landscape by Cy Twombly
+25,smoke rises from the caterpillar's hookah
+26,corvid in red
+27,Monet's pond
+28,Genesis
+29,Death is a black camel that kneels down so we can ride
+30,a cherry tree made of fractals
+29,the end of the sidewalk
+30,a polaroid photo of a bustling city of lights and sky scrapers
+31,The Fig Tree metaphor
+32,God killed Van Gogh.
+33,a cosmic entity alien with four eyes.
+34,a horse with 128 eyes.
+35,a being with an infinite set of eyes (it is omniscient)
+36,A sticky-note magnum opus featuring birds
+37,Moka Pot
+38,the moon is a sickle cell
+39,The Penultimate Supper
+40,Art
+41,surrealism
+42,a god made of wires & dust
+43,a dandelion blown into the universe
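app.py consumes this file through the prompt_list line added near the top of the diff: it reads the second column (the first is just an index), de-duplicates it, and keeps only string entries. Roughly:

import pandas as pd

# first CSV column is an index, second column holds the prompt text
prompts = pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist()
prompt_list = [p for p in set(prompts) if isinstance(p, str)]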