Spaces:
Running
Running
| """ | |
| TimeLapseForge - Universal API Provider Layer v2.2 | |
| All API client imports are LAZY. | |
| Smart prompt truncation per provider. | |
| """ | |
| import os | |
| import io | |
| import time | |
| import base64 | |
| import tempfile | |
| import requests | |
| from PIL import Image | |
| from typing import Optional, Dict, List, Any, Tuple | |
| from abc import ABC, abstractmethod | |
| def _safe_import(package_name): | |
| try: | |
| import importlib | |
| return importlib.import_module(package_name) | |
| except ImportError: | |
| return None | |
| def _require_import(package_name, pip_name=None): | |
| mod = _safe_import(package_name) | |
| if mod is None: | |
| pip = pip_name or package_name | |
| raise ImportError( | |
| "Package '" + pip + "' is not installed. " | |
| "Add it to requirements.txt or use a different provider." | |
| ) | |
| return mod | |
| # ============================================ | |
| # SMART PROMPT TRUNCATOR | |
| # ============================================ | |
def smart_truncate(text, max_length, preserve_end=True):
    """
    Intelligently truncate a prompt to fit within API limits.

    Preserves the most important parts: the leading subject description and,
    when ``preserve_end`` is true, the final comma-separated segment (the text
    after the last ", "), which usually carries the style keywords.

    Args:
        text: Prompt to shorten. Falsy values are returned unchanged.
        max_length: Maximum length of the returned string.
        preserve_end: Try to keep the final ", "-separated segment.

    Returns:
        ``text`` unchanged if it already fits; otherwise a string of at most
        ``max_length`` characters.
    """
    if not text or len(text) <= max_length:
        return text
    if max_length <= 3:
        # No room for an ellipsis. Slicing with max_length - 3 would go
        # negative and keep almost the whole string, so hard-cut instead.
        return text[:max(max_length, 0)]
    # Strategy: keep first part (subject) and last part (style keywords)
    if preserve_end:
        parts = text.rsplit(", ", 1)
        if len(parts) == 2 and len(parts[1]) < max_length // 3:
            suffix = ", " + parts[1]
            available = max_length - len(suffix) - 5  # 5 for " ... "
            if available > 100:
                return text[:available] + " ... " + suffix
    # Simple truncation with a clean cut at a word boundary (only if the
    # boundary keeps at least half of the allowed length).
    truncated = text[:max_length - 3]
    last_space = truncated.rfind(" ")
    if last_space > max_length // 2:
        truncated = truncated[:last_space]
    return truncated + "..."
def split_prompt_parts(full_prompt):
    """
    Split a long prompt into core subject and style modifiers.

    The split point is the ", " separator preceding the earliest style marker
    that occurs in the second half of the prompt; that separator must itself
    lie beyond the first third. Marker matching is case-insensitive.

    Returns:
        (core, style) where style is the reusable suffix, or
        (full_prompt, "") when no split point is found.
    """
    # Common style keywords that appear at the end. They must be lower-case:
    # they are searched for in the lower-cased prompt, so a mixed-case entry
    # such as "4K" could never match.
    style_markers = (
        "photorealistic", "cinematic", "4k", "8k", "detailed",
        "shot on", "lens", "lighting", "consistent", "camera",
        "high quality", "professional", "dramatic",
    )
    lower = full_prompt.lower()
    length = len(full_prompt)
    best_split = length
    for marker in style_markers:
        idx = lower.rfind(marker)
        if idx > length // 2:
            # Find the comma separator before this marker.
            comma_idx = full_prompt.rfind(", ", 0, idx)
            if comma_idx > length // 3:
                best_split = min(best_split, comma_idx)
    if best_split < length:
        core = full_prompt[:best_split].strip().rstrip(",")
        style = full_prompt[best_split:].strip().lstrip(",").strip()
        return core, style
    return full_prompt, ""
| # ============================================ | |
| # BASE PROVIDER CLASS | |
| # ============================================ | |
class BaseProvider(ABC):
    """Common interface and helpers shared by every image-generation backend.

    Subclasses set the class-level metadata below and implement
    :meth:`generate_image`. The default :meth:`img2img` is a fallback for
    providers without native image-to-image support.
    """

    name = "base"  # registry key
    display_name = "Base Provider"  # human-readable label
    website = ""  # where users obtain an API key
    supports_img2img = False
    supports_negative_prompt = True
    default_model = ""
    available_models = []
    requires_package = ""  # module that must be importable, "" if none
    max_prompt_length = 10000  # Default generous limit

    def __init__(self, api_key=""):
        # Accept None so an unset config value can be passed straight through.
        self.api_key = (api_key or "").strip()

    def _truncate(self, prompt, max_len=None):
        """Truncate prompt to max_len, or to the provider's limit when not given."""
        limit = max_len or self.max_prompt_length
        return smart_truncate(prompt, limit)

    @abstractmethod
    def generate_image(
        self, prompt, negative_prompt="",
        width=1024, height=1024,
        seed=None, model=None, **kwargs,
    ):
        """Generate one image from *prompt*; implementations return an RGB PIL image."""

    def img2img(
        self, prompt, image, strength=0.4,
        negative_prompt="", seed=None,
        model=None, **kwargs,
    ):
        """Fallback image-to-image: regenerate from the prompt at the input's size.

        Only ``image.width`` / ``image.height`` are used; the pixel content
        and ``strength`` are ignored. Providers with native support override
        this method.
        """
        return self.generate_image(
            prompt=prompt, negative_prompt=negative_prompt,
            width=image.width, height=image.height,
            seed=seed, model=model, **kwargs,
        )

    @staticmethod
    def _image_to_base64(img, fmt="PNG"):
        """Encode a PIL image as a base64 string in the given format."""
        buf = io.BytesIO()
        img.save(buf, format=fmt)
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    @staticmethod
    def _base64_to_image(b64):
        """Decode base64 data (a bare string or a ``data:`` URI) into an RGB image."""
        if isinstance(b64, str) and b64.startswith("data:"):
            # Drop the "data:<mime>;base64," header; b64decode can't parse it.
            b64 = b64.partition(",")[2]
        data = base64.b64decode(b64)
        return Image.open(io.BytesIO(data)).convert("RGB")

    @staticmethod
    def _url_to_image(url):
        """Download *url* and return it as an RGB image (raises on HTTP errors)."""
        resp = requests.get(url, timeout=120)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")

    @staticmethod
    def _bytes_to_image(data):
        """Decode raw encoded image bytes into an RGB image."""
        return Image.open(io.BytesIO(data)).convert("RGB")
| # ============================================ | |
| # 1. OPENAI (DALL-E 3 / gpt-image-1) | |
| # ============================================ | |
class OpenAIProvider(BaseProvider):
    """OpenAI image generation (DALL-E 2, DALL-E 3, gpt-image-1).

    Prompt length and output size are chosen per model, so a request for a
    size the model does not support falls back to a supported size of the
    same orientation instead of being rejected by the API.
    """

    name = "openai"
    display_name = "OpenAI (DALL-E 3 / gpt-image-1)"
    website = "https://platform.openai.com/api-keys"
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "dall-e-3"
    available_models = ["dall-e-3", "dall-e-2", "gpt-image-1"]
    requires_package = "openai"
    max_prompt_length = 3900  # DALL-E 3 limit is 4000

    # Prompt limits per model, kept slightly under the API maximums.
    # Models not listed here use max_prompt_length (the DALL-E 3 limit).
    _PROMPT_LIMITS = {
        "dall-e-2": 900,  # DALL-E 2 limit is 1000
        "gpt-image-1": 32000,  # gpt-image-1 has a much higher limit
    }

    # Accepted sizes per model: (square, landscape, portrait, extra exact sizes).
    # DALL-E 2 only produces squares, so its landscape/portrait entries are square.
    _SIZE_TABLE = {
        "dall-e-2": ("1024x1024", "1024x1024", "1024x1024", ("256x256", "512x512")),
        "dall-e-3": ("1024x1024", "1792x1024", "1024x1792", ()),
        "gpt-image-1": ("1024x1024", "1536x1024", "1024x1536", ()),
    }

    @classmethod
    def _pick_size(cls, model, width, height):
        """Return a size string *model* accepts for the requested dimensions.

        An exact match is used when supported; otherwise the supported size
        with the same orientation (landscape, portrait or square). Unknown
        models are treated like DALL-E 3.
        """
        square, landscape, portrait, extras = cls._SIZE_TABLE.get(
            model, cls._SIZE_TABLE["dall-e-3"]
        )
        requested = str(width) + "x" + str(height)
        if requested in (square, landscape, portrait) + extras:
            return requested
        if width > height:
            return landscape
        if height > width:
            return portrait
        return square

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image with the OpenAI Images API.

        ``negative_prompt`` and ``seed`` are accepted for interface
        compatibility but not sent. For DALL-E 3, ``quality`` (default "hd")
        and ``style`` (default "natural") can be passed through kwargs.
        """
        openai_mod = _require_import("openai")
        client = openai_mod.OpenAI(api_key=self.api_key)
        model = model or self.default_model
        safe_prompt = self._truncate(
            prompt, self._PROMPT_LIMITS.get(model, self.max_prompt_length)
        )
        size = self._pick_size(model, width, height)

        if model == "gpt-image-1":
            # Called without response_format: use b64_json when present,
            # otherwise download from the returned URL.
            response = client.images.generate(
                model="gpt-image-1", prompt=safe_prompt, n=1, size=size,
            )
            first = response.data[0]
            if getattr(first, "b64_json", None):
                return self._base64_to_image(first.b64_json)
            return self._url_to_image(first.url)

        api_kwargs = dict(
            model=model, prompt=safe_prompt, n=1, size=size,
            response_format="b64_json",
        )
        if model == "dall-e-3":
            api_kwargs["quality"] = kwargs.get("quality", "hd")
            api_kwargs["style"] = kwargs.get("style", "natural")
        response = client.images.generate(**api_kwargs)
        return self._base64_to_image(response.data[0].b64_json)
| # ============================================ | |
| # 2. STABILITY AI | |
| # ============================================ | |
class StabilityProvider(BaseProvider):
    """Stability AI Stable Image v2beta REST API (SD3.x, Core, Ultra)."""

    name = "stability"
    display_name = "Stability AI (SD3 / SDXL)"
    website = "https://platform.stability.ai/account/keys"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "sd3.5-large"
    available_models = [
        "sd3.5-large", "sd3.5-large-turbo", "sd3.5-medium",
        "sd3-large", "sd3-large-turbo", "sd3-medium",
        "stable-image-core", "stable-image-ultra",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://api.stability.ai/v2beta"

    # Aspect ratios accepted by the text-to-image generate endpoints.
    _ASPECT_RATIOS = ("21:9", "16:9", "3:2", "5:4", "1:1", "4:5", "2:3", "9:16", "9:21")

    @classmethod
    def _aspect_ratio(cls, width, height):
        """Return the supported aspect ratio closest to ``width:height``.

        Non-positive dimensions fall back to "1:1".
        """
        if width <= 0 or height <= 0:
            return "1:1"
        target = width / height

        def mismatch(ratio):
            w, h = ratio.split(":")
            r = int(w) / int(h)
            # Symmetric relative error, so too-wide and too-tall weigh equally.
            return max(r, target) / min(r, target)

        return min(cls._ASPECT_RATIOS, key=mismatch)

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Text-to-image generation.

        The generate endpoints are parameterised by ``aspect_ratio`` rather
        than pixel dimensions, so ``width``/``height`` are mapped to the
        nearest supported ratio.
        """
        model = model or self.default_model
        safe_prompt = self._truncate(prompt)
        headers = {"Authorization": "Bearer " + self.api_key, "Accept": "image/*"}
        data = {
            "prompt": safe_prompt,
            "output_format": "png",
            "aspect_ratio": self._aspect_ratio(width, height),
        }
        if negative_prompt:
            data["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            data["seed"] = seed
        if "stable-image" in model:
            # "stable-image-core" -> /generate/core, "stable-image-ultra" -> /generate/ultra
            url = self.API_BASE + "/stable-image/generate/" + model.replace("stable-image-", "")
        else:
            url = self.API_BASE + "/stable-image/generate/sd3"
            data["model"] = model
        # Passing files= makes requests encode the body as multipart/form-data,
        # even though no real file is uploaded.
        resp = requests.post(url, headers=headers, files={"none": ""}, data=data, timeout=120)
        resp.raise_for_status()
        return self._bytes_to_image(resp.content)

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image via the sd3 endpoint in "image-to-image" mode.

        The requested model is now forwarded. Core/Ultra names are only
        routed to their own endpoints for text-to-image, so they fall back
        to ``default_model`` here.
        """
        model = model or self.default_model
        if "stable-image" in model:
            model = self.default_model
        safe_prompt = self._truncate(prompt)
        headers = {"Authorization": "Bearer " + self.api_key, "Accept": "image/*"}
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        buf.seek(0)
        data = {
            "prompt": safe_prompt, "strength": strength,
            "output_format": "png", "mode": "image-to-image",
            "model": model,
        }
        if negative_prompt:
            data["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            data["seed"] = seed
        files = {"image": ("input.png", buf, "image/png")}
        url = self.API_BASE + "/stable-image/generate/sd3"
        resp = requests.post(url, headers=headers, files=files, data=data, timeout=120)
        resp.raise_for_status()
        return self._bytes_to_image(resp.content)
| # ============================================ | |
| # 3. REPLICATE | |
| # ============================================ | |
class ReplicateProvider(BaseProvider):
    """Replicate-hosted models (Flux, SDXL, SD 3.5, ...) via the replicate SDK."""

    name = "replicate"
    display_name = "Replicate (Flux / SDXL / Any)"
    website = "https://replicate.com/account/api-tokens"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "black-forest-labs/flux-1.1-pro"
    available_models = [
        "black-forest-labs/flux-1.1-pro",
        "black-forest-labs/flux-schnell",
        "black-forest-labs/flux-dev",
        "stability-ai/sdxl:latest",
        "stability-ai/stable-diffusion-3.5-large",
        "bytedance/sdxl-lightning-4step:latest",
    ]
    requires_package = "replicate"
    max_prompt_length = 10000

    @staticmethod
    def _output_to_url(output):
        """Extract an image URL from the value returned by ``client.run``.

        Handles a list/tuple of outputs (the first item is used), objects
        exposing a ``url`` attribute, and plain strings.

        Raises:
            ValueError: If the model returned an empty list.
        """
        if isinstance(output, (list, tuple)):
            if not output:
                raise ValueError("No image returned from Replicate")
            output = output[0]
        url = getattr(output, "url", None)
        return str(url) if url else str(output)

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Run a text-to-image model and download the first output image."""
        replicate_mod = _require_import("replicate")
        client = replicate_mod.Client(api_token=self.api_key)
        model_id = model or self.default_model
        safe_prompt = self._truncate(prompt)
        input_params = {"prompt": safe_prompt, "width": width, "height": height}
        # Flux models are not sent a negative prompt.
        if negative_prompt and "flux" not in model_id.lower():
            input_params["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            input_params["seed"] = seed
        output = client.run(model_id, input=input_params)
        return self._url_to_image(self._output_to_url(output))

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image; defaults to SDXL, passing strength as ``prompt_strength``."""
        replicate_mod = _require_import("replicate")
        client = replicate_mod.Client(api_token=self.api_key)
        model_id = model or "stability-ai/sdxl:latest"
        safe_prompt = self._truncate(prompt)
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        buf.seek(0)
        input_params = {"prompt": safe_prompt, "image": buf, "prompt_strength": strength}
        if negative_prompt:
            input_params["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            input_params["seed"] = seed
        output = client.run(model_id, input=input_params)
        return self._url_to_image(self._output_to_url(output))
| # ============================================ | |
| # 4. TOGETHER AI | |
| # ============================================ | |
class TogetherProvider(BaseProvider):
    """Together AI image generation through the ``together`` SDK."""

    name = "together"
    display_name = "Together AI (Flux / SDXL)"
    website = "https://api.together.xyz/settings/api-keys"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "black-forest-labs/FLUX.1.1-pro"
    available_models = [
        "black-forest-labs/FLUX.1.1-pro",
        "black-forest-labs/FLUX.1-schnell-Free",
        "black-forest-labs/FLUX.1-dev",
        "stabilityai/stable-diffusion-xl-base-1.0",
    ]
    requires_package = "together"
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image; ``steps`` may be overridden through kwargs."""
        together_mod = _require_import("together")
        client = together_mod.Together(api_key=self.api_key)
        model_id = model or self.default_model
        safe_prompt = self._truncate(prompt)
        # FLUX.1-schnell variants are few-step models and Together rejects
        # large step counts for them (the free endpoint allows at most 4).
        # NOTE(review): confirm the cap against current Together docs.
        default_steps = 4 if "schnell" in model_id.lower() else 28
        params = dict(model=model_id, prompt=safe_prompt, width=width, height=height,
                      steps=kwargs.get("steps", default_steps), n=1,
                      response_format="b64_json")
        if negative_prompt:
            params["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            params["seed"] = seed
        response = client.images.generate(**params)
        return self._base64_to_image(response.data[0].b64_json)
| # ============================================ | |
| # 5. FAL.AI | |
| # ============================================ | |
class FalProvider(BaseProvider):
    """Fal.ai hosted models via the ``fal_client`` SDK."""

    name = "fal"
    display_name = "Fal.ai (Flux Pro / Fast SDXL)"
    website = "https://fal.ai/dashboard/keys"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "fal-ai/flux-pro/v1.1"
    available_models = [
        "fal-ai/flux-pro/v1.1", "fal-ai/flux/dev", "fal-ai/flux/schnell",
        "fal-ai/flux-realism", "fal-ai/fast-sdxl",
        "fal-ai/stable-diffusion-v35-large", "fal-ai/recraft-v3",
    ]
    requires_package = "fal_client"
    max_prompt_length = 10000

    def _client(self):
        """Return an object exposing ``subscribe`` authenticated with this key.

        Prefers a per-instance ``SyncClient`` so the key is never written to
        the process-wide environment, where concurrent users with different
        keys would overwrite each other. Falls back to the module-level API
        plus ``FAL_KEY`` for fal-client releases without ``SyncClient``.
        """
        fal_client = _require_import("fal_client", "fal-client")
        sync_client_cls = getattr(fal_client, "SyncClient", None)
        if sync_client_cls is not None:
            return sync_client_cls(key=self.api_key)
        os.environ["FAL_KEY"] = self.api_key
        return fal_client

    def _first_image(self, result, error_message):
        """Download the first image in a Fal result dict, or raise ValueError."""
        images = (result or {}).get("images") or []
        if images:
            return self._url_to_image(images[0]["url"])
        raise ValueError(error_message)

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Text-to-image generation with an explicit ``image_size``."""
        client = self._client()
        model_id = model or self.default_model
        arguments = {
            "prompt": self._truncate(prompt),
            "image_size": {"width": width, "height": height},
            "num_images": 1,
        }
        if negative_prompt:
            arguments["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            arguments["seed"] = seed
        result = client.subscribe(model_id, arguments=arguments)
        return self._first_image(result, "No image returned from Fal.ai")

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image; the input is sent inline as a PNG data URI."""
        client = self._client()
        data_uri = "data:image/png;base64," + self._image_to_base64(image)
        arguments = {
            "prompt": self._truncate(prompt),
            "image_url": data_uri,
            "strength": strength,
            "num_images": 1,
        }
        if negative_prompt:
            arguments["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            arguments["seed"] = seed
        model_id = model or "fal-ai/flux/dev/image-to-image"
        result = client.subscribe(model_id, arguments=arguments)
        return self._first_image(result, "No image from Fal.ai img2img")
| # ============================================ | |
| # 6. GOOGLE GEMINI | |
| # ============================================ | |
class GoogleGeminiProvider(BaseProvider):
    """Google Imagen 3 through the ``google-generativeai`` SDK.

    Only the prompt and the optional negative prompt are forwarded; width,
    height and seed are accepted for interface compatibility but not sent.
    """

    name = "google"
    display_name = "Google Gemini (Imagen 3)"
    website = "https://aistudio.google.com/apikey"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "imagen-3.0-generate-002"
    available_models = ["imagen-3.0-generate-002", "imagen-3.0-fast-generate-001"]
    requires_package = "google.generativeai"
    max_prompt_length = 5000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image and return it as RGB; raise ValueError if none comes back."""
        genai = _require_import("google.generativeai", "google-generativeai")
        # NOTE(review): configure() is called on the SDK module, so the key is
        # presumably process-global; concurrent users with different keys may race.
        genai.configure(api_key=self.api_key)
        chosen_model = model or self.default_model
        request = {"prompt": self._truncate(prompt), "number_of_images": 1}
        imagen = genai.ImageGenerationModel(chosen_model)
        if negative_prompt:
            request["negative_prompt"] = smart_truncate(negative_prompt, 2000)
        result = imagen.generate_images(**request)
        if not result.images:
            raise ValueError("No image returned from Imagen")
        return result.images[0]._pil_image.convert("RGB")
| # ============================================ | |
| # 7. HUGGING FACE INFERENCE API | |
| # ============================================ | |
class HuggingFaceProvider(BaseProvider):
    """Hugging Face serverless Inference API (raw HTTP, no SDK needed)."""

    name = "huggingface"
    display_name = "HuggingFace Inference API"
    website = "https://huggingface.co/settings/tokens"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "black-forest-labs/FLUX.1-schnell"
    available_models = [
        "black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stabilityai/stable-diffusion-3.5-large",
        "runwayml/stable-diffusion-v1-5",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://api-inference.huggingface.co/models"

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """POST the prompt to the model endpoint; the response body is the image bytes.

        A 503 reply is retried exactly once after a 20-second pause
        (presumably the model is still loading).
        """
        endpoint = "/".join((self.API_BASE, model or self.default_model))
        auth = {"Authorization": "Bearer " + self.api_key}
        options = {"width": width, "height": height}
        if negative_prompt:
            options["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            options["seed"] = seed
        body = {"inputs": self._truncate(prompt), "parameters": options}

        def send():
            return requests.post(endpoint, headers=auth, json=body, timeout=180)

        reply = send()
        if reply.status_code == 503:
            time.sleep(20)
            reply = send()
        reply.raise_for_status()
        return self._bytes_to_image(reply.content)
| # ============================================ | |
| # 8. xAI GROK | |
| # ============================================ | |
class XAIProvider(BaseProvider):
    """xAI Grok image generation through the OpenAI-compatible SDK.

    Output is always requested at 1024x1024. width, height, seed and
    negative_prompt are accepted for interface compatibility but not sent.
    """

    name = "xai"
    display_name = "xAI Grok (Aurora)"
    website = "https://console.x.ai/team/default/api-keys"
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "grok-2-image"
    available_models = ["grok-2-image"]
    requires_package = "openai"
    max_prompt_length = 4000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image from the xAI endpoint and decode its base64 payload."""
        openai_mod = _require_import("openai")
        client = openai_mod.OpenAI(api_key=self.api_key, base_url="https://api.x.ai/v1")
        request = dict(
            model=model or self.default_model,
            prompt=self._truncate(prompt),
            n=1,
            response_format="b64_json",
            size="1024x1024",
        )
        first = client.images.generate(**request).data[0]
        return self._base64_to_image(first.b64_json)
| # ============================================ | |
| # 9. FIREWORKS AI | |
| # ============================================ | |
class FireworksProvider(BaseProvider):
    """Fireworks AI image generation via its REST images endpoint."""

    name = "fireworks"
    display_name = "Fireworks AI (Flux / SD)"
    website = "https://fireworks.ai/account/api-keys"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "accounts/fireworks/models/flux-1-1-pro"
    available_models = [
        "accounts/fireworks/models/flux-1-1-pro",
        "accounts/fireworks/models/flux-1-schnell-fp8",
        "accounts/fireworks/models/flux-1-dev-fp8",
        "accounts/fireworks/models/stable-diffusion-xl-1024-v1-0",
    ]
    requires_package = ""
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """POST a generation request and decode ``data[0].b64_json`` from the JSON reply."""
        endpoint = "https://api.fireworks.ai/inference/v1/images/generations"
        auth_headers = {
            "Authorization": "Bearer " + self.api_key,
            "Content-Type": "application/json",
        }
        body = {
            "model": model or self.default_model,
            "prompt": self._truncate(prompt),
            "n": 1,
            "size": f"{width}x{height}",
            "response_format": "b64_json",
        }
        if negative_prompt:
            body["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            body["seed"] = seed
        reply = requests.post(endpoint, headers=auth_headers, json=body, timeout=120)
        reply.raise_for_status()
        encoded = reply.json()["data"][0]["b64_json"]
        return self._base64_to_image(encoded)
| # ============================================ | |
| # 10. IDEOGRAM | |
| # ============================================ | |
class IdeogramProvider(BaseProvider):
    """Ideogram image generation via the ``/generate`` REST endpoint."""

    name = "ideogram"
    display_name = "Ideogram (v2 / v2-turbo)"
    website = "https://ideogram.ai/manage-api"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "V_2"
    available_models = ["V_2", "V_2_TURBO", "V_1", "V_1_TURBO"]
    requires_package = ""
    max_prompt_length = 10000

    # Aspect-ratio enum values accepted by the endpoint, with width/height ratios.
    # ASPECT_1_1 comes first so it wins ties.
    _ASPECT_RATIOS = {
        "ASPECT_1_1": 1.0,
        "ASPECT_16_9": 16 / 9, "ASPECT_9_16": 9 / 16,
        "ASPECT_16_10": 16 / 10, "ASPECT_10_16": 10 / 16,
        "ASPECT_4_3": 4 / 3, "ASPECT_3_4": 3 / 4,
        "ASPECT_3_2": 3 / 2, "ASPECT_2_3": 2 / 3,
        "ASPECT_3_1": 3.0, "ASPECT_1_3": 1 / 3,
    }

    @classmethod
    def _aspect_ratio(cls, width, height):
        """Return the Ideogram aspect enum closest to ``width:height`` (1:1 if invalid)."""
        if width <= 0 or height <= 0:
            return "ASPECT_1_1"
        target = width / height
        ratios = cls._ASPECT_RATIOS
        return min(
            ratios,
            key=lambda k: max(ratios[k], target) / min(ratios[k], target),
        )

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image; width/height select the nearest supported aspect ratio.

        Raises:
            ValueError: If the response contains no images.
        """
        url = "https://api.ideogram.ai/generate"
        headers = {"Api-Key": self.api_key, "Content-Type": "application/json"}
        safe_prompt = self._truncate(prompt)
        payload = {
            "image_request": {
                "prompt": safe_prompt, "model": model or self.default_model,
                "magic_prompt_option": "AUTO",
                "aspect_ratio": self._aspect_ratio(width, height),
            }
        }
        if negative_prompt:
            payload["image_request"]["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["image_request"]["seed"] = seed
        resp = requests.post(url, headers=headers, json=payload, timeout=120)
        resp.raise_for_status()
        images = resp.json().get("data") or []
        if not images:
            raise ValueError("No image returned from Ideogram")
        return self._url_to_image(images[0]["url"])
| # ============================================ | |
| # 11. LEONARDO AI | |
| # ============================================ | |
class LeonardoProvider(BaseProvider):
    """Leonardo AI REST API: submit a generation job, then poll until it finishes."""

    name = "leonardo"
    display_name = "Leonardo AI"
    website = "https://app.leonardo.ai/api-access"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "6b645e3a-d64f-4341-a6d8-7a3690fbf042"
    available_models = [
        "6b645e3a-d64f-4341-a6d8-7a3690fbf042",
        "aa77f04e-3eec-4034-9c07-d0f619684628",
        "1e60896f-3c26-4296-8ecc-53e2afecc132",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://cloud.leonardo.ai/api/rest/v1"
    POLL_INTERVAL_SECONDS = 5  # pause before each status check
    POLL_ATTEMPTS = 30  # 30 x 5 s = 150 s before giving up

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Submit a job and poll its status until it completes.

        Raises:
            ValueError: If the job reports FAILED, or completes with no images.
            TimeoutError: If the job has not completed after all poll attempts.
        """
        headers = {"Authorization": "Bearer " + self.api_key, "Content-Type": "application/json"}
        safe_prompt = self._truncate(prompt)
        payload = {
            "prompt": safe_prompt, "modelId": model or self.default_model,
            "width": width, "height": height, "num_images": 1,
        }
        if negative_prompt:
            payload["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["seed"] = seed
        resp = requests.post(self.API_BASE + "/generations", headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
        gen_id = resp.json()["sdGenerationJob"]["generationId"]
        for _ in range(self.POLL_ATTEMPTS):
            time.sleep(self.POLL_INTERVAL_SECONDS)
            poll = requests.get(self.API_BASE + "/generations/" + gen_id, headers=headers, timeout=30)
            poll.raise_for_status()
            # A missing or null record is treated as "not ready yet".
            gen = poll.json().get("generations_by_pk") or {}
            status = gen.get("status")
            if status == "COMPLETE":
                images = gen.get("generated_images") or []
                if images:
                    return self._url_to_image(images[0]["url"])
                # Polling further cannot produce images for a finished job.
                raise ValueError("Leonardo generation completed without images")
            if status == "FAILED":
                raise ValueError("Leonardo generation failed (id " + str(gen_id) + ")")
        raise TimeoutError("Leonardo generation timed out")
| # ============================================ | |
| # 12. CUSTOM OPENAI-COMPATIBLE | |
| # ============================================ | |
class CustomOpenAIProvider(BaseProvider):
    """Any server exposing the OpenAI Images API, optionally at a custom base URL."""

    name = "custom_openai"
    display_name = "Custom OpenAI-Compatible API"
    website = ""
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "dall-e-3"
    available_models = ["dall-e-3", "dall-e-2", "custom"]
    requires_package = "openai"
    max_prompt_length = 3900

    def __init__(self, api_key="", base_url=""):
        super().__init__(api_key)
        # Normalised so "https://host/v1/" and "https://host/v1" behave the same;
        # an empty value means the SDK's default endpoint is used.
        self.base_url = (base_url or "").strip().rstrip("/")

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Generate one image at ``{width}x{height}`` and decode its base64 payload."""
        openai_mod = _require_import("openai")
        client_options = {"api_key": self.api_key}
        if self.base_url:
            client_options["base_url"] = self.base_url
        client = openai_mod.OpenAI(**client_options)
        reply = client.images.generate(
            model=model or self.default_model,
            prompt=self._truncate(prompt),
            n=1,
            size=f"{width}x{height}",
            response_format="b64_json",
        )
        return self._base64_to_image(reply.data[0].b64_json)
| # ============================================ | |
| # 13. DIRECT URL API | |
| # ============================================ | |
class DirectURLProvider(BaseProvider):
    """Generic REST endpoint that takes a JSON prompt and returns an image.

    The request body is ``{"prompt", "width", "height"}`` plus optional
    ``negative_prompt``, ``seed`` and ``model``. The response may be raw image
    bytes (an image Content-Type) or JSON in one of the shapes understood by
    ``_image_from_json``.
    """

    name = "direct_url"
    display_name = "Direct URL API (Any REST Endpoint)"
    website = ""
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "custom"
    available_models = ["custom"]
    requires_package = ""
    max_prompt_length = 50000

    def __init__(self, api_key="", endpoint_url=""):
        super().__init__(api_key)
        self.endpoint_url = (endpoint_url or "").strip()

    def _decode_image_value(self, value):
        """Turn an http(s) URL, a data URI or a bare base64 string into an RGB image."""
        if value.startswith("http"):
            return self._url_to_image(value)
        if value.startswith("data:"):
            # Strip the "data:<mime>;base64," header before decoding.
            value = value.partition(",")[2]
        return self._base64_to_image(value)

    def _image_from_json(self, data):
        """Return the first image found in a decoded JSON response, or None.

        Checks the top-level keys "images", "data", "output", "result" and
        "image", in that order. Each value may be a string, a dict holding
        "b64_json"/"url"/"image", or a list whose first element is one of
        those. Empty lists and empty or non-string values (e.g. a null
        ``b64_json`` next to a valid ``url``) are skipped.
        """
        if not isinstance(data, dict):
            return None
        for key in ("images", "data", "output", "result", "image"):
            item = data.get(key)
            if isinstance(item, list):
                item = item[0] if item else None
            if isinstance(item, dict):
                for subkey in ("b64_json", "url", "image"):
                    val = item.get(subkey)
                    if isinstance(val, str) and val:
                        return self._decode_image_value(val)
            elif isinstance(item, str) and item:
                return self._decode_image_value(item)
        return None

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """POST the prompt to ``endpoint_url`` and parse the image from the reply.

        Raises:
            ValueError: If no endpoint is configured or no image can be parsed.
        """
        if not self.endpoint_url:
            raise ValueError("No endpoint URL provided")
        headers = {"Authorization": "Bearer " + self.api_key, "Content-Type": "application/json"}
        safe_prompt = self._truncate(prompt)
        payload = {"prompt": safe_prompt, "width": width, "height": height}
        if negative_prompt:
            payload["negative_prompt"] = smart_truncate(negative_prompt, 10000)
        if seed is not None:
            payload["seed"] = seed
        if model and model != "custom":
            payload["model"] = model
        resp = requests.post(self.endpoint_url, headers=headers, json=payload, timeout=180)
        resp.raise_for_status()
        if "image" in resp.headers.get("Content-Type", ""):
            return self._bytes_to_image(resp.content)
        image = self._image_from_json(resp.json())
        if image is None:
            raise ValueError("Could not parse image from API response")
        return image
| # ============================================ | |
| # PROVIDER REGISTRY | |
| # ============================================ | |
# Registry of provider classes keyed by each class's ``name`` attribute.
# Iteration order determines the order of get_provider_info() results.
PROVIDERS = {
    provider_cls.name: provider_cls
    for provider_cls in (
        OpenAIProvider,
        StabilityProvider,
        ReplicateProvider,
        TogetherProvider,
        FalProvider,
        GoogleGeminiProvider,
        HuggingFaceProvider,
        XAIProvider,
        FireworksProvider,
        IdeogramProvider,
        LeonardoProvider,
        CustomOpenAIProvider,
        DirectURLProvider,
    )
}

# Reverse lookup: human-readable display name -> registry key.
PROVIDER_DISPLAY_NAMES = {
    provider_cls.display_name: key for key, provider_cls in PROVIDERS.items()
}
def get_provider(provider_name, api_key, **kwargs):
    """Instantiate the provider registered under *provider_name*.

    Args:
        provider_name: Key in ``PROVIDERS`` (e.g. "openai").
        api_key: API key/token handed to the provider's constructor.
        **kwargs: ``base_url`` for "custom_openai" and ``endpoint_url`` for
            "direct_url" (each defaults to ""); ignored for other providers.

    Raises:
        ValueError: If *provider_name* is not registered.
    """
    provider_cls = PROVIDERS.get(provider_name)
    if provider_cls is None:
        raise ValueError("Unknown provider: " + str(provider_name))
    extra_arg = {"custom_openai": "base_url", "direct_url": "endpoint_url"}.get(provider_name)
    if extra_arg is None:
        return provider_cls(api_key=api_key)
    return provider_cls(api_key=api_key, **{extra_arg: kwargs.get(extra_arg, "")})
def get_provider_info():
    """Describe every registered provider for display in the UI.

    Returns:
        A list of dicts (in ``PROVIDERS`` order) with the provider's metadata
        and ``package_installed``, which is True when no package is required
        or the required module imports successfully.
    """
    info = []
    for key, provider_cls in PROVIDERS.items():
        pkg = provider_cls.requires_package
        # Import the full dotted path: checking only the top-level name would
        # report "google.generativeai" as installed whenever any other
        # package in the "google" namespace is present.
        installed = not pkg or _safe_import(pkg) is not None
        info.append({
            "name": key, "display_name": provider_cls.display_name,
            "website": provider_cls.website, "supports_img2img": provider_cls.supports_img2img,
            "default_model": provider_cls.default_model,
            # Copy so callers cannot mutate the shared class-level list.
            "available_models": list(provider_cls.available_models),
            "requires_package": pkg, "package_installed": installed,
            "max_prompt_length": provider_cls.max_prompt_length,
        })
    return info
def get_models_for_provider(provider_name):
    """Return a copy of the model list for *provider_name*, or [] if it is unknown.

    A copy is returned so callers cannot mutate the class-level
    ``available_models`` list that every instance of the provider shares.
    """
    provider_cls = PROVIDERS.get(provider_name)
    return list(provider_cls.available_models) if provider_cls else []