Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,6 +5,7 @@ import gradio as gr
|
|
| 5 |
import torchvision.transforms as T
|
| 6 |
import sys
|
| 7 |
import spaces
|
|
|
|
| 8 |
|
| 9 |
subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
|
| 10 |
os.chdir("HairFastGAN")
|
|
@@ -34,20 +35,36 @@ from hair_swap import HairFast, get_parser
|
|
| 34 |
|
| 35 |
hair_fast = HairFast(get_parser().parse_args([]))
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
@spaces.GPU
|
| 38 |
def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 39 |
final_image = hair_fast.swap(source, target_1, target_2)
|
| 40 |
return T.functional.to_pil_image(final_image)
|
| 41 |
|
| 42 |
with gr.Blocks() as demo:
|
| 43 |
-
gr.Markdown("
|
| 44 |
with gr.Row():
|
| 45 |
source = gr.Image(label="Photo that you want to replace the hair", type="filepath")
|
| 46 |
target_1 = gr.Image(label="Reference hair you want to get", type="filepath")
|
| 47 |
target_2 = gr.Image(label="Reference color hair you want to get (optional)", type="filepath")
|
| 48 |
btn = gr.Button("Get the haircut")
|
| 49 |
output = gr.Image(label="Your result")
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
btn.click(fn=swap_hair, inputs=[source, target_1, target_2], outputs=[output])
|
| 52 |
|
| 53 |
demo.launch()
|
|
|
|
| 5 |
import torchvision.transforms as T
|
| 6 |
import sys
|
| 7 |
import spaces
|
| 8 |
+
from PIL import Image
|
| 9 |
|
| 10 |
subprocess.run(["git", "clone", "https://github.com/AIRI-Institute/HairFastGAN"], check=True)
|
| 11 |
os.chdir("HairFastGAN")
|
|
|
|
| 35 |
|
| 36 |
hair_fast = HairFast(get_parser().parse_args([]))
|
| 37 |
|
| 38 |
+
def resize(image_path):
|
| 39 |
+
img = Image.open("image_path")
|
| 40 |
+
square_size = 1024
|
| 41 |
+
|
| 42 |
+
left = (img.width - square_size) / 2
|
| 43 |
+
top = (img.height - square_size) / 2
|
| 44 |
+
right = (img.width + square_size) / 2
|
| 45 |
+
bottom = (img.height + square_size) / 2
|
| 46 |
+
|
| 47 |
+
img_cropped = img.crop((left, top, right, bottom))
|
| 48 |
+
return img_cropped
|
| 49 |
+
|
| 50 |
@spaces.GPU
|
| 51 |
def swap_hair(source, target_1, target_2, progress=gr.Progress(track_tqdm=True)):
|
| 52 |
+
target_2 = target_2 if target_2 else target_1
|
| 53 |
final_image = hair_fast.swap(source, target_1, target_2)
|
| 54 |
return T.functional.to_pil_image(final_image)
|
| 55 |
|
| 56 |
with gr.Blocks() as demo:
|
| 57 |
+
gr.Markdown("## HairFastGan")
|
| 58 |
with gr.Row():
|
| 59 |
source = gr.Image(label="Photo that you want to replace the hair", type="filepath")
|
| 60 |
target_1 = gr.Image(label="Reference hair you want to get", type="filepath")
|
| 61 |
target_2 = gr.Image(label="Reference color hair you want to get (optional)", type="filepath")
|
| 62 |
btn = gr.Button("Get the haircut")
|
| 63 |
output = gr.Image(label="Your result")
|
| 64 |
+
gr.Examples(examples=[("michael_cera-min.png", "leo_square-min.png", "pink_hair_celeb-min.png")])
|
| 65 |
+
source.upload(fn=resize, input=source, output=source)
|
| 66 |
+
target_1.upload(fn=resize, input=target_1, output=target_1)
|
| 67 |
+
target_2.upload(fn=resize, input=target_2, output=target_2)
|
| 68 |
btn.click(fn=swap_hair, inputs=[source, target_1, target_2], outputs=[output])
|
| 69 |
|
| 70 |
demo.launch()
|