import os
import sys
import torch
import base64
import io
from PIL import Image
import tempfile
import shutil
from typing import Dict, Any, List
import json
from xml.sax.saxutils import escape as xml_escape

# Try to import cairosvg for SVG to PNG conversion
try:
    import cairosvg
    CAIROSVG_AVAILABLE = True
except ImportError:
    CAIROSVG_AVAILABLE = False

# Add current directory to path for imports
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)


def _truncate(text: str, limit: int) -> str:
    """Return *text* cut to *limit* characters, appending "..." only if it was cut."""
    return text if len(text) <= limit else f"{text[:limit]}..."


def _text_svg(text: str, width: int, height: int, fill: str = "black") -> str:
    """Build a minimal SVG document drawing *text* centred on a white canvas.

    The text is XML-escaped so user-supplied prompts containing ``<``, ``>``
    or ``&`` cannot break (or inject into) the SVG markup.
    """
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}">'
        '<rect width="100%" height="100%" fill="white"/>'
        f'<text x="50%" y="50%" font-family="sans-serif" font-size="14" '
        f'text-anchor="middle" fill="{fill}">{xml_escape(text)}</text>'
        "</svg>"
    )


def svg_to_pil_image(svg_string: str, width: int = 224, height: int = 224) -> Image.Image:
    """Rasterize an SVG string to an RGB PIL image of size ``(width, height)``.

    Uses cairosvg when it is installed. The rendered PNG is composited onto a
    white background and converted to RGB, so every return path yields the
    same image mode (the PNG produced by cairosvg may carry an alpha channel).

    Falls back to a blank white RGB image when cairosvg is unavailable or
    rendering fails (e.g. malformed SVG). Rendering failures are reported
    with a printed warning instead of being raised.
    """
    if not CAIROSVG_AVAILABLE:
        # No rasterizer available: return a blank canvas of the requested size.
        return Image.new('RGB', (width, height), color='white')
    try:
        png_bytes = cairosvg.svg2png(
            bytestring=svg_string.encode('utf-8'),
            output_width=width,
            output_height=height,
        )
        with Image.open(io.BytesIO(png_bytes)) as rendered:
            # convert() forces decoding, so the result outlives the file handle.
            rgba = rendered.convert('RGBA')
        background = Image.new('RGBA', rgba.size, 'white')
        return Image.alpha_composite(background, rgba).convert('RGB')
    except Exception as e:
        print(f"Warning: SVG rendering failed, returning blank image: {e}")
        return Image.new('RGB', (width, height), color='white')


try:
    import pydiffvg
    from diffusers import StableDiffusionPipeline
    from omegaconf import OmegaConf
    DEPENDENCIES_AVAILABLE = True
except ImportError as e:
    print(f"Warning: Some dependencies not available: {e}")
    DEPENDENCIES_AVAILABLE = False


class EndpointHandler:
    """Inference handler for DiffSketchEdit-style SVG generation and editing.

    NOTE(review): ``__call__`` currently renders placeholder SVGs only. It
    does not use the diffusion pipeline or pydiffvg set up in ``__init__``.
    """

    # Size of the blank image returned for empty input or on failure.
    DEFAULT_CANVAS_SIZE = 256
    # Upper bound on a client-requested canvas, to cap per-request memory use.
    MAX_CANVAS_SIZE = 2048

    def __init__(self, path=""):
        """Initialize the handler for DiffSketchEdit model.

        Args:
            path: Accepted for the handler interface (presumably the model
                repository path supplied by the hosting runtime); unused here.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Define these up front so the attributes exist even in mock mode.
        self.cfg = None
        self.pipe = None

        if not DEPENDENCIES_AVAILABLE:
            print("Warning: Dependencies not available, handler will return mock responses")
            return

        # Create a minimal config for DiffSketchEdit
        self.cfg = OmegaConf.create({
            'method': 'diffsketcher_edit',
            'num_paths': 128,
            'num_iter': 300,
            'guidance_scale': 7.5,
            'edit_strength': 0.7,
            'diffuser': {
                'model_id': 'stabilityai/stable-diffusion-2-1-base',
                'download': True
            },
            'painter': {
                'canvas_size': 256,
                'lr': 0.02,
                'color_lr': 0.01
            }
        })

        # Initialize the diffusion pipeline (best-effort: failure leaves pipe None)
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.cfg.diffuser.model_id,
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None

        # Set up pydiffvg (best-effort)
        try:
            pydiffvg.set_print_timing(False)
            pydiffvg.set_device(self.device)
        except Exception as e:
            print(f"Warning: Could not initialize pydiffvg: {e}")

    def _canvas_size(self, requested) -> int:
        """Return *requested* as a canvas side length in ``[1, MAX_CANVAS_SIZE]``.

        Falls back to the configured default when the value is missing,
        not convertible to ``int``, or non-positive.
        """
        default = int(self.cfg.painter.canvas_size)
        if requested is None:
            return default
        try:
            size = int(requested)
        except (TypeError, ValueError):
            return default
        if size <= 0:
            return default
        return min(size, self.MAX_CANVAS_SIZE)

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        Process the input data and return the edited SVG as PIL Image.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt for SVG editing
                - parameters: Optional dict (may be absent or null) with
                  input_svg, edit_instruction, canvas_size. num_paths,
                  num_iter, guidance_scale and edit_strength are accepted
                  but currently ignored.

        Returns:
            RGB PIL Image of the edited SVG. A blank white image is returned
            for an empty prompt or if processing fails.
        """
        blank_size = (self.DEFAULT_CANVAS_SIZE, self.DEFAULT_CANVAS_SIZE)
        try:
            prompt = data.get("inputs", "")
            if not prompt:
                # Nothing to render: return a blank canvas.
                return Image.new('RGB', blank_size, color='white')
            prompt = str(prompt)

            # If dependencies aren't available, return a mock response
            if not DEPENDENCIES_AVAILABLE:
                mock_svg = _text_svg(f"Mock DiffSketchEdit for: {prompt}", 256, 256)
                return svg_to_pil_image(mock_svg, 256, 256)

            # `parameters` (and its entries) may be missing or explicitly null.
            parameters = data.get("parameters") or {}
            input_svg = parameters.get("input_svg")
            edit_instruction = str(parameters.get("edit_instruction") or prompt)
            canvas_size = self._canvas_size(parameters.get("canvas_size"))

            # Simplified placeholder: a real implementation would parse and
            # modify input_svg using the diffusion/pydiffvg pipeline.
            if input_svg:
                text = f"Edited: {_truncate(edit_instruction, 30)}"
            else:
                text = _truncate(prompt, 20)
            edited_svg = _text_svg(text, canvas_size, canvas_size)
            return svg_to_pil_image(edited_svg, canvas_size, canvas_size)
        except Exception as e:
            # Endpoint boundary: report the failure but still return an image.
            print(f"Warning: DiffSketchEdit handler failed, returning blank image: {e}")
            return Image.new('RGB', blank_size, color='white')


# For testing
if __name__ == "__main__":
    handler = EndpointHandler()
    test_data = {
        "inputs": "add colorful flowers to the scene",
        "parameters": {
            "edit_instruction": "add bright flowers",
            "num_paths": 64,
            "num_iter": 200
        }
    }
    result = handler(test_data)
    print(result)