Spaces:
Running
Running
| import base64 | |
| import logging | |
| import os | |
| import time | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| from gradio.themes.utils import sizes | |
# LOGGING
# Dedicated application logger with timestamped records, so the messages of
# one try-on request (submission, polling, completion) can be followed in order.
logger = logging.getLogger("TRYON")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
handler.setFormatter(formatter)
# logging.getLogger() hands back the same logger object every time this module
# is executed (e.g. on a reload), so attach the handler only once; otherwise
# every record would be printed once per accumulated handler.
if not logger.handlers:
    logger.addHandler(handler)
# IMAGE ASSETS
# Example images are looked up in an "assets" folder located next to this script.
_SCRIPT_DIR = os.path.dirname(__file__)
ASSETS_DIR = os.path.join(_SCRIPT_DIR, "assets")
# API CONFIG
# The endpoint can be overridden through the environment; the public FASHN API
# is used when no override is present.
FASHN_ENDPOINT_URL = os.environ.get("FASHN_ENDPOINT_URL", "https://api.fashn.ai/v1")
# The API key is a credential and must not live in the source tree; it is read
# from the environment (e.g. a secret configured in the hosting platform).
# NOTE(review): a key was previously hard-coded here. It must be considered
# leaked and should be revoked / rotated with the API provider.
FASHN_API_KEY = os.environ.get("FASHN_API_KEY")
# Fail fast at start-up with explicit exceptions instead of `assert`, which is
# removed entirely when Python runs with the -O flag.
if not FASHN_ENDPOINT_URL:
    raise RuntimeError("Please set the FASHN_ENDPOINT_URL environment variable")
if not FASHN_API_KEY:
    raise RuntimeError("Please set the FASHN_API_KEY environment variable")
| # ----------------- HELPER FUNCTIONS ----------------- # | |
# Translates the category labels offered in the UI into the identifiers that
# are sent in the "category" field of the API request.
CATEGORY_API_MAPPING = {
    "Top": "tops",
    "Bottom": "bottoms",
    "Full-body": "one-pieces",
}
def opencv_load_image_from_http(url: str, timeout: float = 60) -> np.ndarray:
    """Download an image with HTTP GET and decode it with OpenCV.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server. Without a timeout a stalled
            download would block the calling request indefinitely.

    Returns:
        The image decoded by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.

    Raises:
        requests.exceptions.RequestException: If the request fails or the
            server answers with an error status.
        ValueError: If the downloaded bytes cannot be decoded as an image.
    """
    with requests.get(url, timeout=timeout) as response:
        response.raise_for_status()
        image_data = np.frombuffer(response.content, np.uint8)
    image = cv2.imdecode(image_data, cv2.IMREAD_COLOR)
    # cv2.imdecode reports an undecodable buffer by returning None rather than
    # raising, which would only surface later as an obscure OpenCV error.
    if image is None:
        raise ValueError(f"Could not decode image downloaded from {url}")
    return image
def encode_img_to_base64(img: np.ndarray) -> str:
    """Encode an image as a JPEG and return it as a Base64 data URI.

    Args:
        img: Image array accepted by ``cv2.imencode``.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.

    Raises:
        ValueError: If OpenCV reports that the JPEG encoding failed.
    """
    # cv2.imencode returns a (success flag, encoded buffer) pair.
    success, buffer = cv2.imencode(".jpg", img)
    if not success:
        raise ValueError("Could not encode image as JPEG")
    payload = base64.b64encode(buffer.tobytes()).decode("utf-8")
    return f"data:image/jpeg;base64,{payload}"
def parse_checkboxes(checkboxes):
    """Map each checkbox label to ``True`` under a lower-case, underscore-separated key."""
    return {label.lower().replace(" ", "_"): True for label in checkboxes}
def make_api_request(session, url, headers, data=None, method="GET", max_retries=3, timeout=60, retry_delay=2):
    """Send a GET or POST request and return the decoded JSON body, retrying on failure.

    Args:
        session: Object exposing ``get`` / ``post`` like ``requests.Session``.
        url: Target URL.
        headers: HTTP headers sent with every attempt.
        data: JSON-serialisable payload; only used for POST.
        method: ``"GET"`` or ``"POST"`` (case-insensitive).
        max_retries: Total number of attempts, at least 1.
        timeout: Per-attempt timeout in seconds.
        retry_delay: Seconds to sleep between two attempts.

    Returns:
        The parsed JSON response of the first successful attempt.

    Raises:
        ValueError: If ``method`` is unsupported or ``max_retries`` is below 1.
        Exception: If every attempt failed; the last request error is chained.
    """
    # Validate arguments up front: these are programming errors, not transient
    # failures, so they must never enter the retry loop.
    verb = method.upper()
    if verb not in ("GET", "POST"):
        raise ValueError(f"Unsupported HTTP method: {method}")
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    # Same logger name as the rest of the application, so retry messages share
    # its format instead of going through a bare print().
    log = logging.getLogger("TRYON")
    for attempt in range(1, max_retries + 1):
        try:
            if verb == "GET":
                response = session.get(url, headers=headers, timeout=timeout)
            else:
                response = session.post(url, headers=headers, json=data, timeout=timeout)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            if attempt == max_retries:  # no attempts left: give up
                raise Exception(f"API call failed after {max_retries} attempts: {str(e)}") from e
            log.warning("Attempt %d of %d failed (%s). Retrying...", attempt, max_retries, e)
            time.sleep(retry_delay)
| # ----------------- CORE FUNCTION ----------------- # | |
def get_tryon_result(
    model_image,
    garment_image,
    garment_photo_type,
    category,
    nsfw_filter,
    cover_feet,
    adjust_hands,
    restore_background,
    restore_clothes,
    guidance_scale,
    timesteps,
    seed,
    num_samples,
):
    """Submit a try-on job to the FASHN API, wait for it and return the result images.

    Args:
        model_image: Image of the person, as delivered by the UI image component.
        garment_image: Image of the garment, as delivered by the UI image component.
        garment_photo_type: Photo type label; sent lower-cased.
        category: Key of ``CATEGORY_API_MAPPING``.
        nsfw_filter, cover_feet, adjust_hands, restore_background,
        restore_clothes, guidance_scale, timesteps, seed, num_samples:
            Forwarded unchanged to the API under the same field names.

    Returns:
        A list with one RGB image array per URL in the prediction output.

    Raises:
        gr.Error: If an input image is missing, the job cannot be submitted,
            the status check fails, the prediction fails or does not finish
            within three minutes, or a result image cannot be downloaded.
    """
    logger.info("Starting new try-on request...")

    # An image component without an upload yields None; report that clearly
    # instead of letting OpenCV fail on it.
    if model_image is None or garment_image is None:
        raise gr.Error("Please provide both a model image and a garment image.")

    # preprocessing: swap RGB -> BGR (the channel order OpenCV's encoder
    # expects), then encode each image as a base64 JPEG data URI
    model_image, garment_image = (
        encode_img_to_base64(cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) for img in (model_image, garment_image)
    )

    # prepare data for API request
    data = {
        "model_image": model_image,
        "garment_image": garment_image,
        "garment_photo_type": garment_photo_type.lower(),
        "category": CATEGORY_API_MAPPING[category],
        "nsfw_filter": nsfw_filter,
        "cover_feet": cover_feet,
        "adjust_hands": adjust_hands,
        "restore_background": restore_background,
        "restore_clothes": restore_clothes,
        "guidance_scale": guidance_scale,
        "timesteps": timesteps,
        "seed": seed,
        "num_samples": num_samples,
    }
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {FASHN_API_KEY}"}

    # The session is used as a context manager so that its pooled connections
    # are released once submission and polling are over.
    with requests.Session() as session:
        # make API request
        try:
            response_data = make_api_request(
                session, f"{FASHN_ENDPOINT_URL}/run", headers=headers, data=data, method="POST"
            )
            pred_id = response_data.get("id")
        except Exception as e:
            raise gr.Error(f"Prediction request failed: {str(e)}") from e
        # Without an id there is nothing to poll; stop here rather than
        # querying a meaningless status URL until the timeout expires.
        if not pred_id:
            raise gr.Error(f"Prediction request returned no id: {response_data.get('error')}")
        logger.info(f"Prediction ID: {pred_id}")

        # poll the status of the prediction
        start_time = time.time()
        while True:
            if time.time() - start_time > 180:  # 3 minutes timeout
                raise gr.Error("Maximum polling time exceeded.")
            try:
                status_data = make_api_request(
                    session, f"{FASHN_ENDPOINT_URL}/status/{pred_id}", headers=headers, method="GET"
                )
            except Exception as e:
                raise gr.Error(f"Status check failed: {str(e)}") from e

            status = status_data.get("status")
            if status == "completed":
                logger.info("Prediction completed.")
                break
            # any status other than the known in-progress ones is a failure
            if status not in ["starting", "in_queue", "processing"]:
                raise gr.Error(f"Prediction failed with id {pred_id}: {status_data.get('error')}")
            logger.info(f"Prediction status: {status}")
            time.sleep(3)

    # get the result images
    result_imgs = []
    for output_url in status_data["output"]:
        try:
            result_img = opencv_load_image_from_http(output_url)
            # OpenCV decodes to BGR; convert back to RGB for display
            result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
        except Exception as e:
            raise gr.Error(f"Failed to load result image: {str(e)}") from e
        result_imgs.append(result_img)
    return result_imgs
| # ----------------- GRADIO UI ----------------- # | |
def _read_html(filename):
    """Return the text of an HTML file stored in the same folder as this script."""
    # Resolve against the script location (as ASSETS_DIR does) instead of the
    # current working directory, so the app also starts when launched from
    # another folder. The encoding is explicit to avoid platform defaults.
    path = os.path.join(os.path.dirname(__file__), filename)
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


# Static HTML fragments rendered at the top and bottom of the page.
banner = _read_html("banner.html")
tips = _read_html("tips.html")
footer = _read_html("footer.html")
# Extra CSS passed to gr.Blocks: limits the size of displayed images and sets
# the page background colour. The first rule must be closed with "}" before
# the second starts, otherwise the second rule is parsed as part of the first.
CUSTOM_CSS = """
.image-container img {
    max-width: 384px;
    max-height: 576px;
    margin: 0 auto;
    border-radius: 0px;
}
.gradio-container {background-color: #fafafa}
"""
def _example_paths(subdir):
    """Return the full paths of the files in ``ASSETS_DIR/<subdir>``, sorted by name."""
    folder = os.path.join(ASSETS_DIR, subdir)
    # os.listdir returns entries in arbitrary order; sorting keeps the example
    # galleries in the same order on every start and platform.
    return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]


# Page layout: banner and tips on top, then one row with three columns
# (model image, garment image, results), then the footer.
# NOTE(review): the nesting below was reconstructed from statement order
# because the indentation of this section was lost — confirm it matches the
# intended layout (in particular which controls sit inside each accordion).
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
    gr.HTML(banner)
    gr.HTML(tips)
    with gr.Row():
        # Column 1: the person photo and its options
        with gr.Column():
            model_image = gr.Image(label="Foto Model", type="numpy")
            with gr.Accordion("Model Image Controls", open=False):
                cover_feet = gr.Checkbox(label="Cover Feet", value=False)
                adjust_hands = gr.Checkbox(label="Adjust Hands", value=False)
                restore_background = gr.Checkbox(label="Restore Background", value=False)
                restore_clothes = gr.Checkbox(label="Restore Clothes", value=False)
                nsfw_filter = gr.Checkbox(label="NSFW Filter", value=True)
            example_model = gr.Examples(
                label="Pilih model",
                inputs=model_image,
                examples_per_page=10,
                examples=_example_paths("models"),
            )
        # Column 2: the garment photo and its description
        with gr.Column():
            garment_image = gr.Image(label="Produk", type="numpy")
            garment_photo_type = gr.Radio(
                choices=["Auto", "Flat-Lay", "Model"], label="Select Photo Type", value="Auto"
            )
            category = gr.Radio(choices=["Top", "Bottom", "Full-body"], label="Select Category", value="Top")
            example_garment = gr.Examples(
                label="Pilih produk",
                inputs=garment_image,
                examples_per_page=10,
                examples=_example_paths("garments"),
            )
        # Column 3: results, the run button and the sampling parameters
        with gr.Column():
            result_gallery = gr.Gallery(label="Hasil", show_label=True, elem_id="gallery")
            run_button = gr.Button("Coba")
            with gr.Accordion("Sampling Controls", open=False):
                guidance_scale = gr.Slider(minimum=1.5, maximum=3, value=2.0, step=0.1, label="Guidance Scale")
                timesteps = gr.Slider(minimum=10, maximum=50, step=1, value=50, label="Timesteps")
                seed = gr.Number(label="Seed", value=42, precision=0)
                num_samples = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Samples")
    # The inputs are listed in the same order as the parameters of
    # get_tryon_result, which receives them positionally.
    run_button.click(
        fn=get_tryon_result,
        inputs=[
            model_image,
            garment_image,
            garment_photo_type,
            category,
            nsfw_filter,
            cover_feet,
            adjust_hands,
            restore_background,
            restore_clothes,
            guidance_scale,
            timesteps,
            seed,
            num_samples,
        ],
        outputs=[result_gallery],
    )
    gr.HTML(footer)
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()