Spaces:
Paused
Paused
| import gradio as gr | |
| import numpy as np | |
| import time | |
| import requests | |
| import json | |
| import os | |
| import random | |
| import tempfile | |
| import logging | |
| import threading | |
| import asyncio | |
| from PIL import Image | |
| from io import BytesIO | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util import Retry | |
# Backend endpoints and the auth header value are read from the environment
# (e.g. Space secrets); os.getenv yields None for any variable that is unset.
odnapi, fetapi, auth_token = [
    os.getenv(env_name) for env_name in ("odnapi_url", "fetapi_url", "auth_token")
]

# Module-wide logger; basicConfig sets up root logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def split_image(img):
    """Cut *img* into four quadrants.

    The midpoints use floor division, so on odd dimensions the extra pixel
    row/column ends up in the right/bottom tiles.

    Returns:
        A list of four cropped images ordered top-left, top-right,
        bottom-left, bottom-right.
    """
    full_w, full_h = img.size
    mid_x, mid_y = full_w // 2, full_h // 2
    # Row-major order: iterate the vertical band first, then the horizontal one.
    boxes = [
        (left, top, right, bottom)
        for top, bottom in ((0, mid_y), (mid_y, full_h))
        for left, right in ((0, mid_x), (mid_x, full_w))
    ]
    return [img.crop(box) for box in boxes]
def save_image(img, suffix='.png'):
    """Write *img* to a fresh temporary file and return that file's path.

    The file is created with ``delete=False`` so it outlives this call (the
    caller hands the path to Gradio). *suffix* only sets the filename
    extension; the data is always encoded as PNG.
    """
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    with tmp_file:
        img.save(tmp_file, 'PNG')
    return tmp_file.name
async def niji_api(prompt, progress=gr.Progress(), max_retries=5, backoff_factor=0.1):
    """Submit *prompt* to the backend and stream the generated images to Gradio.

    This is an async generator (wired to ``btn.click`` with the gallery as
    output). Each ``yield`` is a list of ``(image_path, caption)`` tuples that
    replaces the gallery contents.

    Flow:
      1. POST ``{"msg": prompt}`` to ``fetapi`` and read ``messageId`` from
         the JSON reply.
      2. Poll ``{odnapi}/message/{messageId}?expireMins=2`` every 1-2 s until
         the reply's ``progress`` reaches 100. Whenever a
         ``progressImageUrl`` is present, yield its four quadrants captioned
         ``"<progress>% done"``.
      3. Download ``response.imageUrl`` from the last successful poll reply
         and yield its four quadrants captioned ``"image i/4"``.

    Args:
        prompt: Prompt text, forwarded verbatim as the ``msg`` field.
        progress: Gradio progress tracker; Gradio injects it because the
            default value is ``gr.Progress()``.
        max_retries: ``total`` for the urllib3 ``Retry`` mounted on the session.
        backoff_factor: ``backoff_factor`` for that ``Retry``.

    Raises:
        ValueError: if the initial POST fails ("Invalid Response"), or if the
            final image cannot be downloaded and split.
    """
    iters = 1
    progress(iters / 32, desc="Sending request to MidJourney Server")

    retries = Retry(total=max_retries, backoff_factor=backoff_factor,
                    status_forcelist=[429, 500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retries)
    loop = asyncio.get_running_loop()

    # One Session for the whole generation so the TCP connection is reused.
    # The ``with`` closes it (and its pooled connections) when the generator
    # finishes, raises, or is closed by Gradio.
    with requests.Session() as session:
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        async def run_blocking(func, *args):
            # requests, PIL decoding and temp-file writes are synchronous; run
            # them in the default executor so a slow backend does not stall
            # the event loop that serves every other Gradio session.
            return await loop.run_in_executor(None, func, *args)

        def submit():
            # Start the generation job; raises RequestException on transport
            # or HTTP-status errors.
            response = session.post(
                fetapi,
                headers={'Content-Type': 'application/json'},
                data=json.dumps({'msg': prompt}),
                timeout=5.0,
            )
            response.raise_for_status()
            return response

        def poll(message_id):
            # One status query; returns the decoded JSON reply.
            response = session.get(
                f'{odnapi}/message/{message_id}?expireMins=2',
                headers={'Authorization': auth_token},
                timeout=5.0,
            )
            response.raise_for_status()
            return response.json()

        def download_and_split_image(url):
            # Best effort: return the four saved quadrant paths, or [] (with
            # the error logged) when download, decode or save fails. Always
            # returning a list lets callers iterate the result safely.
            try:
                response = session.get(url, timeout=5.0)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                return [save_image(part) for part in split_image(img)]
            except Exception:
                logger.exception("Failed to fetch image")
                return []

        try:
            response = await run_blocking(submit)
        except requests.exceptions.RequestException as e:
            logger.exception("Failed to make POST request")
            raise ValueError("Invalid Response") from e
        data = response.json()
        message_id = data['messageId']

        prog = 0
        iters += 5
        progress(iters / 48, desc="Waiting in the generate queue")

        # NOTE: there is no overall deadline; polling continues until the
        # backend reports progress >= 100.
        while prog < 100:
            try:
                data = await run_blocking(poll, message_id)
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers a non-JSON body. Wait before retrying so a
                # failing backend is not hit in a tight loop.
                logger.warning("Failure in getting message response: %s", e)
                await asyncio.sleep(random.uniform(1, 2))
                continue
            prog = data.get('progress', 0)
            if progress_image_url := data.get('progressImageUrl'):
                # Real previews exist: push iters far negative so the
                # synthetic queue progress below stops being reported.
                iters = -100
                previews = await run_blocking(download_and_split_image, progress_image_url)
                if previews:
                    yield [(img, f"{prog}% done") for img in previews]
            await asyncio.sleep(random.uniform(1, 2))
            # Synthetic queue progress: iters grows by 1-2 per poll and is
            # shown as a fraction of 48. At r >= 0.8 the previous label is
            # kept.
            r = iters / 48
            if r < 0.4:
                desc = "Waiting in the generate queue"
            elif r < 0.6:
                desc = "Still queueing"
            elif r < 0.8:
                desc = "Almost done"
            if iters > 0:
                progress(r, desc=desc)
            iters += random.uniform(1, 2)

        # The last successful poll reply carries the final image URL.
        image_url = data['response']['imageUrl']
        images = await run_blocking(download_and_split_image, image_url)
        if not images:
            raise ValueError("Failed to download the generated image")
        yield [(img, f"image {idx+1}/4") for idx, img in enumerate(images)]
# Gradio UI: a header, a prompt row (textbox + button) and a gallery.
# Clicking the button runs niji_api on the textbox value and streams each of
# its yields into the gallery.
with gr.Blocks() as demo:
    gr.HTML('''
        <div style="text-align: center; max-width: 650px; margin: 0 auto;">
          <div style="
              display: inline-flex;
              gap: 0.8rem;
              font-size: 1.75rem;
              justify-content: center;
              margin-bottom: 10px;
            ">
            <h1 style="font-weight: 900; align-items: center; margin-bottom: 7px; margin-top: 20px;">
              MidJourney / NijiJourney Playground 🎨
            </h1>
          </div>
          <div>
            <p style="align-items: center; margin-bottom: 7px;">
              Demo for the <a href="https://MidJourney.com/" target="_blank">MidJourney</a>, draw with heart and love.
            </p>
          </div>
        </div>
    ''')
    with gr.Column(variant="panel"):
        with gr.Row():
            text = gr.Textbox(
                label="Enter your prompt",
                value="1girl,long hair,looking at viewer,kawaii,serafuku --s 250 --niji 5",
                max_lines=3,
                container=False,
            )
            # scale=0 keeps the button at its natural width next to the textbox.
            btn = gr.Button("Generate image", scale=0)
        gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", height="4096")
    btn.click(niji_api, text, gallery)

# Enable the request queue, which generator handlers such as niji_api need in
# Gradio 3.x; concurrency_count is the Gradio 3.x worker count for that queue.
demo.queue(concurrency_count=2)
demo.launch(debug=True)