Spaces:
Running
Running
| import subprocess | |
| import shlex | |
# Install the bundled MagicQuillV2 Gradio component wheel before it is imported below.
# pip is run through the current interpreter (``sys.executable -m pip``) so the
# package lands in the same environment that executes this script; a bare "pip"
# found on PATH may belong to a different Python installation.
# Installation is best-effort: a failing pip run is reported but does not abort
# here; the later ``import gradio_magicquillv2`` decides whether the app can start.
import sys

_install_result = subprocess.run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "./gradio_magicquillv2-0.0.1-py3-none-any.whl",
    ]
)
if _install_result.returncode != 0:
    print(
        "Warning: installing gradio_magicquillv2 wheel exited with code "
        f"{_install_result.returncode}"
    )
| import sys | |
| import os | |
| import gradio as gr | |
| import tempfile | |
| import numpy as np | |
| import io | |
| import base64 | |
| import json | |
| import uvicorn | |
| import torch | |
| from fastapi import FastAPI, Request | |
| from fastapi.concurrency import run_in_threadpool | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from gradio_client import Client, handle_file | |
| from gradio_magicquillv2 import MagicQuillV2 | |
| from PIL import Image | |
| from util import ( | |
| read_base64_image as read_base64_image_utils, | |
| tensor_to_base64, | |
| get_mask_bbox | |
| ) | |
# --- Configuration ---
# Hugging Face Space ids ("owner/space-name") of the remote services this UI
# forwards work to: BACKEND_URL is called with api_name="/generate" and
# SAM_URL with api_name="/segment". HF_TOKEN is read from the environment and
# passed to each Client (presumably only needed for private Spaces — confirm).
hf_token = os.environ.get("HF_TOKEN")
BACKEND_URL = "LiuZichen/MagicQuillV2"
SAM_URL = "LiuZichen/MagicQuillHelper"


def _connect_client(space, label):
    """Return a gradio_client ``Client`` for *space*, or None if connecting fails.

    The attempt and any failure are reported with ``print`` so the app still
    starts; request handlers check for a None client and answer with an error.
    """
    print(f"Connecting to {label} at: {space}")
    try:
        return Client(space, hf_token=hf_token)
    except Exception as e:
        print(f"Failed to connect to {label}: {e}")
        return None


backend_client = _connect_client(BACKEND_URL, "backend")
sam_client = _connect_client(SAM_URL, "SAM client")
| # --- Helper Functions --- | |
def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
    """Gradio click handler: run one generation on the backend Space.

    *x* is the MagicQuillV2 component value. Its ``from_frontend`` dict supplies
    the images and masks, and ``from_backend['prompt']`` supplies the positive
    prompt. The remaining arguments come from the parameter widgets.

    The backend result is stored in ``x['from_backend']['generated_image']``.
    It is set to None when the backend client is missing or the call fails.
    *x* is mutated in place and returned.

    Raises KeyError if *x* lacks any of the expected frontend/backend keys
    (checked before the backend is contacted).
    """
    frontend = x['from_frontend']
    # Positional order must match the input list of the backend's /generate
    # endpoint (app_backend.py).
    image_inputs = [
        frontend[key]
        for key in (
            'img',               # merged_image
            'total_mask',
            'original_image',
            'add_color_image',
            'add_edge_mask',
            'remove_edge_mask',
            'fill_mask',
            'add_prop_image',
        )
    ]
    positive_prompt = x['from_backend']['prompt']
    backend_state = x["from_backend"]

    if backend_client is None:
        print("Backend client not initialized")
        backend_state["generated_image"] = None
        return x

    try:
        backend_state["generated_image"] = backend_client.predict(
            *image_inputs,
            positive_prompt,
            negative_prompt,
            fine_edge,
            fix_perspective,
            grow_size,
            edge_strength,
            color_strength,
            local_strength,
            seed,
            steps,
            cfg,
            api_name="/generate",
        )
    except Exception as e:
        print(f"Error in generation: {e}")
        backend_state["generated_image"] = None
    return x
# --- Gradio UI ---
# Layout: intro text, the MagicQuillV2 canvas, then a "Run" button next to a
# collapsed accordion of generation parameters. Clicking "Run" sends the canvas
# value plus every parameter to generate_image_handler and writes the result
# back into the canvas component.
with gr.Blocks(title="MagicQuill V2") as demo:
    with gr.Row(elem_classes="row"):
        text = gr.Markdown(
            "\n"
            "# Welcome to MagicQuill V2! Give us a [GitHub star](https://github.com/zliucz/magicquillv2) if you are interested.\n"
            "Click the [link](https://magicquill.art/v2) to view our demo and tutorial. The paper is on [ArXiv](https://arxiv.org/abs/2512.03046) now.\n"
        )

    with gr.Row():
        ms = MagicQuillV2()

    with gr.Row():
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column(), gr.Accordion("parameters", open=False):
            negative_prompt = gr.Textbox(
                label="Negative Prompt", value="", interactive=True
            )
            fine_edge = gr.Radio(
                label="Fine Edge", choices=['enable', 'disable'],
                value='disable', interactive=True,
            )
            fix_perspective = gr.Radio(
                label="Fix Perspective", choices=['enable', 'disable'],
                value='disable', interactive=True,
            )
            grow_size = gr.Slider(
                label="Grow Size", minimum=10, maximum=100,
                value=50, step=1, interactive=True,
            )
            edge_strength = gr.Slider(
                label="Edge Strength", minimum=0.0, maximum=5.0,
                value=0.6, step=0.01, interactive=True,
            )
            color_strength = gr.Slider(
                label="Color Strength", minimum=0.0, maximum=5.0,
                value=1.5, step=0.01, interactive=True,
            )
            local_strength = gr.Slider(
                label="Local Strength", minimum=0.0, maximum=5.0,
                value=1.0, step=0.01, interactive=True,
            )
            seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
            steps = gr.Slider(
                label="Steps", minimum=0, maximum=50, value=20, interactive=True
            )
            cfg = gr.Slider(
                label="CFG", minimum=0.0, maximum=20.0,
                value=3.5, step=0.1, interactive=True,
            )

    # Order must match generate_image_handler's positional parameters.
    generation_inputs = [
        ms,
        negative_prompt,
        fine_edge,
        fix_perspective,
        grow_size,
        edge_strength,
        color_strength,
        local_strength,
        seed,
        steps,
        cfg,
    ]
    btn.click(generate_image_handler, inputs=generation_inputs, outputs=ms)
# --- FastAPI App ---
# Outer ASGI application; the Gradio Blocks are mounted onto it at the end of
# the module.
app = FastAPI()

# Permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests,
# and Starlette's CORSMiddleware docs state that allow_origins/methods/headers
# must be explicit for credentials to be allowed — confirm this policy is intended.
_cors_options = {
    "allow_origins": ['*'],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
def get_root_url(request: Request, route_path: str, root_path: str | None):
    """Replacement for ``gradio.route_utils.get_root_url``.

    Ignores *request* and *route_path* and reports the configured *root_path*
    verbatim (possibly None) as the root URL. Parameter names are kept as-is
    in case Gradio calls this with keyword arguments.
    """
    return root_path


# Monkeypatch Gradio's root-URL resolution with the function above.
setattr(gr.route_utils, "get_root_url", get_root_url)
| # gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo") | |
async def generate_image(request: Request):
    """Forward a JSON generation request to the backend Space's ``/generate`` API.

    The JSON body must contain every key in ``field_order`` below. The values
    are passed to the backend positionally, in that order.

    Returns ``{'res': <backend result>}`` on success. Returns
    ``{'error': <message>}`` when the backend client is not connected or the
    call fails, including when a required key is missing.
    """
    payload = await request.json()
    if backend_client is None:
        return {'error': 'Backend client not connected'}

    field_order = (
        "merged_image",
        "total_mask",
        "original_image",
        "add_color_image",
        "add_edge_mask",
        "remove_edge_mask",
        "fill_mask",
        "add_prop_image",
        "positive_prompt",
        "negative_prompt",
        "fine_edge",
        "fix_perspective",
        "grow_size",
        "edge_strength",
        "color_strength",
        "local_strength",
        "seed",
        "steps",
        "cfg",
    )
    try:
        predict = backend_client.predict
        positional = [payload[name] for name in field_order]
        # The gradio_client call blocks, so run it off the event loop.
        result = await run_in_threadpool(predict, *positional, api_name="/generate")
    except Exception as e:
        print(f"Error in backend generation: {e}")
        return {'error': str(e)}
    return {'res': result}
async def process_background_img(request: Request):
    """Run ``util.process_background`` on the posted JSON and return a WebP data URL.

    The decoded JSON body is passed unchanged to ``process_background``. Its
    result (presumably a [1, H, W, 3] uint8 or float tensor — confirm) is
    encoded with ``tensor_to_base64(quality=80, method=6)``. The encoding is
    returned prefixed with ``data:image/webp;base64,``.
    """
    payload = await request.json()
    # Imported at call time rather than at module load.
    from util import process_background

    webp_b64 = tensor_to_base64(process_background(payload), quality=80, method=6)
    return "data:image/webp;base64," + webp_b64
def _round_points(points):
    """Round each ``{'x', 'y'}`` point to ints and JSON-encode the list.

    Returns None when *points* is None or empty. The input dicts are not modified.
    """
    if not points:
        return None
    return json.dumps(
        [{'x': int(round(p['x'])), 'y': int(round(p['y']))} for p in points]
    )


def _normalize_bboxes(bboxes):
    """Convert ``{startX, startY, endX, endY}`` boxes to JSON ``[x_min, y_min, x_max, y_max]``.

    Boxes with any missing (None) corner value are skipped. Corner values are
    truncated with ``int``. The min corner is clamped at 0; the max corner is
    left unclamped. Returns None when no valid box remains, matching how empty
    point lists are reported.
    """
    valid = []
    for bbox in bboxes or ():
        corners = [bbox.get(k) for k in ("startX", "startY", "endX", "endY")]
        if any(c is None for c in corners):
            continue
        sx, sy, ex, ey = (int(c) for c in corners)
        valid.append((max(min(sx, ex), 0), max(min(sy, ey), 0), max(sx, ex), max(sy, ey)))
    return json.dumps(valid) if valid else None


def _scale_points(points_json, scale):
    """Scale a JSON list of ``{'x', 'y'}`` points by *scale* (truncating to int)."""
    if not points_json:
        return points_json
    return json.dumps(
        [{'x': int(p['x'] * scale), 'y': int(p['y'] * scale)} for p in json.loads(points_json)]
    )


def _scale_bboxes(bboxes_json, scale):
    """Scale a JSON list of ``[x_min, y_min, x_max, y_max]`` boxes by *scale*."""
    if not bboxes_json:
        return bboxes_json
    return json.dumps([[int(v * scale) for v in box] for box in json.loads(bboxes_json)])


def _remove_file(path):
    """Best-effort deletion of a temporary file; failures are only reported."""
    try:
        os.unlink(path)
    except OSError as err:
        print(f"Could not remove temporary file {path}: {err}")


def _save_temp_webp(image):
    """Write *image* to a new ``.webp`` temp file (quality 80) and return its path.

    The file is removed again if saving fails, so the caller only has to clean
    up a path that was actually returned.
    """
    handle = tempfile.NamedTemporaryFile(suffix=".webp", delete=False)
    try:
        with handle:
            # Write through the open handle instead of reopening by name, which
            # is not possible on every platform while the file is open.
            image.save(handle, format="WEBP", quality=80)
    except Exception:
        _remove_file(handle.name)
        raise
    return handle.name


async def segmentation(request: Request):
    """Segment the object indicated by clicks/boxes using the SAM helper Space.

    Request JSON keys (all optional):
      - ``image``: base64 image string, decoded by ``read_base64_image_utils``.
      - ``coordinates_positive`` / ``coordinates_negative``: lists of ``{'x', 'y'}``.
      - ``bboxes``: list of ``{startX, startY, endX, endY}``.

    Returns ``{"error": False, "segmentation_image": <PNG data URL>,
    "segmentation_bbox": {startX, startY, endX, endY}}`` on success. The PNG is
    the RGB input with the SAM mask as its alpha channel, at the input's
    original size. Returns ``{"error": <message>}`` when the SAM client is not
    initialized or any step fails, including malformed prompt data.
    """
    json_data = await request.json()
    image_base64 = json_data.get("image", None)
    coordinates_positive = json_data.get("coordinates_positive", None)
    coordinates_negative = json_data.get("coordinates_negative", None)
    bboxes = json_data.get("bboxes", None)
    if sam_client is None:
        return {"error": "sam client not initialized"}

    try:
        pos_coordinates = _round_points(coordinates_positive)
        neg_coordinates = _round_points(coordinates_negative)
        bboxes_xyxy = _normalize_bboxes(bboxes)
        print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")

        pil_image = Image.open(read_base64_image_utils(image_base64))
        original_size = pil_image.size
        w, h = original_size
        # Downscale (never upscale) so the short side is 512 px to cut transfer
        # size; the prompts are scaled by the same factor.
        scale = 512 / min(w, h)
        if scale < 1:
            new_size = (int(w * scale), int(h * scale))
            pil_image_resized = pil_image.resize(new_size, Image.LANCZOS)
            print(f"Resized image for segmentation: {original_size} -> {new_size}")
            pos_coordinates = _scale_points(pos_coordinates, scale)
            neg_coordinates = _scale_points(neg_coordinates, scale)
            bboxes_xyxy = _scale_bboxes(bboxes_xyxy, scale)
        else:
            pil_image_resized = pil_image

        temp_in_path = _save_temp_webp(pil_image_resized)
        try:
            # The gradio_client call blocks, so run it off the event loop.
            result_path = await run_in_threadpool(
                sam_client.predict,
                handle_file(temp_in_path),
                pos_coordinates,
                neg_coordinates,
                bboxes_xyxy,
                api_name="/segment",
            )
        finally:
            # Always delete the upload, even if the remote call fails.
            _remove_file(temp_in_path)

        if isinstance(result_path, (list, tuple)):
            result_path = result_path[0]
        if not result_path or not os.path.exists(result_path):
            raise RuntimeError("Client returned invalid result path")

        mask_pil = Image.open(result_path)
        if mask_pil.mode != 'L':
            mask_pil = mask_pil.convert('L')
        pil_image = pil_image.convert("RGB")
        # The mask comes back at the (possibly downscaled) upload size; map it
        # back onto the original image.
        if pil_image.size != mask_pil.size:
            mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
        r, g, b = pil_image.split()
        res_pil = Image.merge("RGBA", (r, g, b, mask_pil))

        mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
        mask_bbox = get_mask_bbox(mask_tensor)
        if mask_bbox:
            x_min, y_min, x_max, y_max = mask_bbox
            seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
        else:
            seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
        print(seg_bbox)

        buffered = io.BytesIO()
        res_pil.save(buffered, format="PNG")
        image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "error": False,
            "segmentation_image": "data:image/png;base64," + image_base64_res,
            "segmentation_bbox": seg_bbox,
        }
    except Exception as e:
        print(f"Error in segmentation: {e}")
        return {"error": str(e)}
# Mount the Gradio app
# Configure Gradio's request queue (default_concurrency_limit=20, max_size=40),
# then mount the Blocks on the FastAPI app at "/" with root_path "/demo".
demo.queue(max_size=40, default_concurrency_limit=20)
app = gr.mount_gradio_app(app, demo, path="/", root_path="/demo")

if __name__ == "__main__":
    # Serve on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")