ysharma HF Staff commited on
Commit
71eef5e
·
1 Parent(s): 41df49f

create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import requests
2
+ import gradio as gr
3
+ from gradio_client import Client
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from diffusers import StableDiffusionUpscalePipeline
7
+ import torch
8
+ import os
9
+
10
+ HF_TOKEN = os.environ.get('HF_TOKEN')
11
+ client_if = Client("huggingface-projects/IF", hf_token=HF_TOKEN)
12
+ client_pick = Client("yuvalkirstain/PickScore")
13
+
14
+ # load upscaling model and scheduler
15
+ model_id = "stabilityai/stable-diffusion-x4-upscaler"
16
+ pipeline_upscale = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
17
+ pipeline_upscale = pipeline_upscale.to("cuda")
18
+
19
+ def get_IF_op(prompt, neg_prompt):
20
+ filepaths = client_if.predict(prompt, neg_prompt, 1,4,7.0, 'smart100',50, api_name="/generate64")
21
+ folder_path = filepaths[0]
22
+ file_list = os.listdir(folder_path)
23
+ file_list = [os.path.join(folder_path, f) for f in file_list if f != 'captions.json']
24
+ return file_list
25
+
26
+ def get_pickscores(prompt, file_list):
27
+ #Get the predictons
28
+ probabilities1 = client_pick.predict(prompt, file_list[0], file_list[1], fn_index=0)
29
+ probabilities2 = client_pick.predict(prompt, file_list[2], file_list[3], fn_index=0)
30
+ probabilities_all = list(probabilities1) + list(probabilities2)
31
+ max_score = max(probabilities_all)
32
+ max_score_index = probabilities_all.index(max_score)
33
+ best_match_image = file_list[max_score_index]
34
+ return best_match_image
35
+
36
+ def get_upscale_op(prompt, best_match_image):
37
+ # let's get the image
38
+ low_res_img = Image.open(best_match_image).convert("RGB")
39
+ low_res_img = low_res_img.resize((128, 128))
40
+ upscaled_image = pipeline_upscale(prompt=prompt, image=low_res_img).images[0]
41
+ return upscaled_image
42
+