Spaces:
Runtime error
Runtime error
updated code; added theme
Browse files
app.py
CHANGED
|
@@ -24,24 +24,26 @@ def get_IF_op(prompt, neg_prompt):
|
|
| 24 |
file_list = os.listdir(folder_path)
|
| 25 |
file_list = [os.path.join(folder_path, f) for f in file_list if f != 'captions.json']
|
| 26 |
print(f"^^file list is: {file_list}")
|
| 27 |
-
return file_list
|
| 28 |
|
| 29 |
-
def get_pickscores(prompt,
|
| 30 |
print("inside get_pickscores")
|
| 31 |
#Get the predictons
|
| 32 |
-
probabilities1 = client_pick.predict(prompt,
|
| 33 |
-
probabilities2 = client_pick.predict(prompt,
|
| 34 |
probabilities_all = list(probabilities1) + list(probabilities2)
|
| 35 |
max_score = max(probabilities_all)
|
| 36 |
max_score_index = probabilities_all.index(max_score)
|
| 37 |
-
best_match_image =
|
| 38 |
return best_match_image
|
| 39 |
|
|
|
|
| 40 |
def get_upscale_op(prompt, gallery_if):
|
| 41 |
print("inside get_upscale_op")
|
| 42 |
print(f"^^gallery_if is: {gallery_if}")
|
|
|
|
| 43 |
# get pickscores
|
| 44 |
-
best_match_image = get_pickscores(prompt,
|
| 45 |
# let's get the best pick!
|
| 46 |
low_res_img = Image.open(best_match_image).convert("RGB")
|
| 47 |
low_res_img = low_res_img.resize((128, 128))
|
|
@@ -50,18 +52,20 @@ def get_upscale_op(prompt, gallery_if):
|
|
| 50 |
#upscaled_image.save("upsampled.png")
|
| 51 |
return upscaled_image
|
| 52 |
|
| 53 |
-
with gr.Blocks() as demo:
|
| 54 |
with gr.Row():
|
| 55 |
-
with gr.Column():
|
| 56 |
prompt = gr.Textbox(label='Prompt')
|
| 57 |
neg_prompt = gr.Textbox(label='Negative Prompt')
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
b2 = gr.Button("Get the best generation using Pick-A-Pic")
|
| 62 |
image_picakapic = gr.Image(label="PickAPic Evaluated Output")
|
| 63 |
|
| 64 |
b1.click(get_IF_op,[prompt, neg_prompt], gallery_if)
|
|
|
|
| 65 |
b2.click(get_upscale_op,[prompt, gallery_if], image_picakapic)
|
| 66 |
|
| 67 |
-
|
|
|
|
|
|
| 24 |
file_list = os.listdir(folder_path)
|
| 25 |
file_list = [os.path.join(folder_path, f) for f in file_list if f != 'captions.json']
|
| 26 |
print(f"^^file list is: {file_list}")
|
| 27 |
+
return file_list
|
| 28 |
|
| 29 |
+
def get_pickscores(prompt, image_tmps):
|
| 30 |
print("inside get_pickscores")
|
| 31 |
#Get the predictons
|
| 32 |
+
probabilities1 = client_pick.predict(prompt, image_tmps[0], image_tmps[1], fn_index=0)
|
| 33 |
+
probabilities2 = client_pick.predict(prompt, image_tmps[2], image_tmps[3], fn_index=0)
|
| 34 |
probabilities_all = list(probabilities1) + list(probabilities2)
|
| 35 |
max_score = max(probabilities_all)
|
| 36 |
max_score_index = probabilities_all.index(max_score)
|
| 37 |
+
best_match_image = image_tmps[max_score_index]
|
| 38 |
return best_match_image
|
| 39 |
|
| 40 |
+
|
| 41 |
def get_upscale_op(prompt, gallery_if):
|
| 42 |
print("inside get_upscale_op")
|
| 43 |
print(f"^^gallery_if is: {gallery_if}")
|
| 44 |
+
image_tmps = [val['name'] for val in gallery_if]
|
| 45 |
# get pickscores
|
| 46 |
+
best_match_image = get_pickscores(prompt, image_tmps)
|
| 47 |
# let's get the best pick!
|
| 48 |
low_res_img = Image.open(best_match_image).convert("RGB")
|
| 49 |
low_res_img = low_res_img.resize((128, 128))
|
|
|
|
| 52 |
#upscaled_image.save("upsampled.png")
|
| 53 |
return upscaled_image
|
| 54 |
|
| 55 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
| 56 |
with gr.Row():
|
| 57 |
+
with gr.Column(scale=4):
|
| 58 |
prompt = gr.Textbox(label='Prompt')
|
| 59 |
neg_prompt = gr.Textbox(label='Negative Prompt')
|
| 60 |
+
with gr.Column(scale=1):
|
| 61 |
+
b1 = gr.Button('Generate').style(full_width=True)
|
| 62 |
+
gallery_if = gr.Gallery(label='IF Space outputs', ).style(columns=4, object_fit="contain", preview=True)
|
| 63 |
b2 = gr.Button("Get the best generation using Pick-A-Pic")
|
| 64 |
image_picakapic = gr.Image(label="PickAPic Evaluated Output")
|
| 65 |
|
| 66 |
b1.click(get_IF_op,[prompt, neg_prompt], gallery_if)
|
| 67 |
+
prompt.submit(get_IF_op,[prompt, neg_prompt], gallery_if)
|
| 68 |
b2.click(get_upscale_op,[prompt, gallery_if], image_picakapic)
|
| 69 |
|
| 70 |
+
|
| 71 |
+
demo.launch()
|