import base64
import html
import io
import json
import os
import shutil
import sys
import tempfile
from typing import Any, Dict, List

import torch
from PIL import Image
# Try to import cairosvg for SVG to PNG conversion.
# cairosvg loads the native cairo library (via cairocffi) at import time. When
# the Python package is installed but libcairo itself is missing, that raises
# OSError rather than ImportError, so both mean "cairosvg not usable" here
# instead of crashing the whole module import.
try:
    import cairosvg
    CAIROSVG_AVAILABLE = True
except (ImportError, OSError):
    CAIROSVG_AVAILABLE = False

# Add the directory containing this file to the front of sys.path so sibling
# modules can be imported.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir)
def svg_to_pil_image(svg_string: str, width: int = 224, height: int = 224) -> Image.Image:
    """Rasterize an SVG document into an opaque RGB PIL image.

    Rendering uses cairosvg when it was importable. The rendered PNG is
    flattened onto a white background, so the result is always an RGB image,
    the same mode as the blank fallback.

    Args:
        svg_string: SVG markup to render.
        width: Output width in pixels.
        height: Output height in pixels.

    Returns:
        The rendered image. If cairosvg is unavailable, or rendering/decoding
        fails (e.g. malformed or empty SVG), a plain white RGB image of
        ``width`` x ``height`` is returned instead.
    """
    if CAIROSVG_AVAILABLE:
        try:
            png_bytes = cairosvg.svg2png(
                bytestring=svg_string.encode("utf-8"),
                output_width=width,
                output_height=height,
            )
            # Image.open is lazy; convert() forces decoding here, inside the
            # try, so a bad PNG falls back instead of failing later in callers.
            rendered = Image.open(io.BytesIO(png_bytes)).convert("RGBA")
            canvas = Image.new("RGB", rendered.size, color="white")
            # Use the alpha channel as the mask so transparent areas stay white.
            canvas.paste(rendered, mask=rendered.getchannel("A"))
            return canvas
        except Exception as exc:
            print(f"Warning: could not render SVG, returning blank image: {exc}")
    # Fallback: a plain white image of the requested size.
    return Image.new("RGB", (width, height), color="white")
# Optional heavy dependencies for the DiffSketchEdit pipeline. If any of them
# cannot be imported, DEPENDENCIES_AVAILABLE is False and a warning is printed.
try:
    import pydiffvg
    from diffusers import StableDiffusionPipeline
    from omegaconf import OmegaConf
except ImportError as missing_dep:
    print(f"Warning: Some dependencies not available: {missing_dep}")
    DEPENDENCIES_AVAILABLE = False
else:
    DEPENDENCIES_AVAILABLE = True
class EndpointHandler:
    """Request handler for a simplified DiffSketchEdit endpoint.

    Calling an instance with a request dict returns a PIL image rasterized
    from an SVG via ``svg_to_pil_image``. The diffusion pipeline is loaded in
    ``__init__`` but is not used by ``__call__``. The SVG is either a
    placeholder built from the prompt or the caller-supplied ``input_svg``.
    """

    # Upper bound (pixels per side) for the request-supplied canvas_size.
    # The rasterizer allocates width * height pixels, so an unbounded value
    # coming from the payload could exhaust memory.
    MAX_CANVAS_SIZE = 2048

    def __init__(self, path=""):
        """
        Initialize the handler for DiffSketchEdit model.

        Args:
            path: Not used by this implementation (presumably the model
                directory supplied by the serving runtime; the diffusion model
                is fetched by ``model_id`` instead).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Define these up front so the attributes exist even on the mock path.
        self.cfg = None
        self.pipe = None
        if not DEPENDENCIES_AVAILABLE:
            print("Warning: Dependencies not available, handler will return mock responses")
            return
        # Create a minimal config for DiffSketchEdit
        self.cfg = OmegaConf.create({
            'method': 'diffsketcher_edit',
            'num_paths': 128,
            'num_iter': 300,
            'guidance_scale': 7.5,
            'edit_strength': 0.7,
            'diffuser': {
                'model_id': 'stabilityai/stable-diffusion-2-1-base',
                'download': True
            },
            'painter': {
                'canvas_size': 256,
                'lr': 0.02,
                'color_lr': 0.01
            }
        })
        # Initialize the diffusion pipeline (best effort: failure leaves pipe=None)
        try:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.cfg.diffuser.model_id,
                torch_dtype=torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
        # Set up pydiffvg
        try:
            pydiffvg.set_print_timing(False)
            pydiffvg.set_device(self.device)
        except Exception as e:
            print(f"Warning: Could not initialize pydiffvg: {e}")

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """
        Process the input data and return the edited SVG as PIL Image.

        Args:
            data: Dictionary containing:
                - inputs: Text prompt for SVG editing
                - parameters: Optional dict (``None`` is treated as empty).
                  Keys used here:
                  ``input_svg`` (rendered as-is when given) and
                  ``canvas_size`` (coerced to int, clamped to
                  [1, MAX_CANVAS_SIZE]). ``edit_instruction``, ``num_paths``,
                  ``num_iter``, ``guidance_scale`` and ``edit_strength`` are
                  accepted but ignored by this simplified implementation.

        Returns:
            PIL Image of the edited SVG. On empty input or any error, a
            256x256 white image is returned instead of raising.
        """
        try:
            prompt = data.get("inputs", "")
            if not prompt:
                return self._blank_image()

            # If dependencies aren't available, return a mock response
            if not DEPENDENCIES_AVAILABLE:
                mock_svg = self._placeholder_svg(str(prompt), 256)
                return svg_to_pil_image(mock_svg, 256, 256)

            # An explicit JSON null must behave like a missing "parameters" key.
            parameters = data.get("parameters") or {}
            input_svg = parameters.get("input_svg")
            default_size = self.cfg.painter.canvas_size
            canvas_size = self._clamp_canvas_size(
                parameters.get("canvas_size", default_size), default_size
            )

            if input_svg:
                # Editing is not implemented: the input SVG is rendered unchanged.
                edited_svg = input_svg
            else:
                # No input SVG: render a placeholder showing the prompt.
                edited_svg = self._placeholder_svg(str(prompt), canvas_size)
            return svg_to_pil_image(edited_svg, canvas_size, canvas_size)
        except Exception as exc:
            print(f"Warning: request failed, returning blank image: {exc}")
            return self._blank_image()

    @staticmethod
    def _blank_image(size: int = 256) -> Image.Image:
        """Return a plain white RGB square, used for empty input and errors."""
        return Image.new("RGB", (size, size), color="white")

    @classmethod
    def _clamp_canvas_size(cls, value: Any, default: int) -> int:
        """Coerce a requested canvas size to an int in [1, MAX_CANVAS_SIZE].

        Values that ``int()`` cannot convert fall back to ``default``.
        """
        try:
            size = int(value)
        except (TypeError, ValueError, OverflowError):
            size = int(default)
        return max(1, min(size, cls.MAX_CANVAS_SIZE))

    @staticmethod
    def _placeholder_svg(text: str, size: int) -> str:
        """Return a minimal ``size`` x ``size`` SVG document displaying ``text``.

        ``text`` comes from the request payload, so it is XML-escaped before
        being interpolated into the markup.
        """
        half = size // 2
        font_size = max(8, size // 16)
        return (
            f'<svg xmlns="http://www.w3.org/2000/svg" width="{size}" height="{size}" '
            f'viewBox="0 0 {size} {size}">'
            f'<rect width="{size}" height="{size}" fill="white"/>'
            f'<text x="{half}" y="{half}" text-anchor="middle" '
            f'font-family="sans-serif" font-size="{font_size}" fill="black">'
            f"{html.escape(text)}"
            "</text></svg>"
        )
# For testing: running this file directly sends one sample request through
# the handler and prints the returned image object.
if __name__ == "__main__":
    smoke_request = {
        "inputs": "add colorful flowers to the scene",
        "parameters": dict(
            edit_instruction="add bright flowers",
            num_paths=64,
            num_iter=200,
        ),
    }
    smoke_handler = EndpointHandler()
    print(smoke_handler(smoke_request))