TimeLapseForge / api_providers.py
Adnan
Update api_providers.py
d7e092a verified
"""
TimeLapseForge - Universal API Provider Layer v2.2
All API client imports are LAZY.
Smart prompt truncation per provider.
"""
import os
import io
import time
import base64
import tempfile
import requests
from PIL import Image
from typing import Optional, Dict, List, Any, Tuple
from abc import ABC, abstractmethod
def _safe_import(package_name):
try:
import importlib
return importlib.import_module(package_name)
except ImportError:
return None
def _require_import(package_name, pip_name=None):
    """Import *package_name* or raise an ``ImportError`` with an install hint.

    Args:
        package_name: Importable module name (e.g. ``"fal_client"``).
        pip_name: Distribution name shown in the error message when it differs
            from the module name (e.g. ``"fal-client"``).

    Returns:
        The imported module.

    Raises:
        ImportError: if the module is not installed.
    """
    module = _safe_import(package_name)
    if module is not None:
        return module
    install_name = pip_name or package_name
    raise ImportError(
        "Package '" + install_name + "' is not installed. "
        "Add it to requirements.txt or use a different provider."
    )
# ============================================
# SMART PROMPT TRUNCATOR
# ============================================
def smart_truncate(text, max_length, preserve_end=True):
    """
    Intelligently truncate a prompt to fit within API limits.

    Keeps the leading subject description and, when ``preserve_end`` is true,
    tries to keep the final ``", "``-separated style section intact by cutting
    out the middle (marked with ``" ... "``).

    Args:
        text: Prompt to shorten. Falsy values (``""``, ``None``) are returned as-is.
        max_length: Maximum length of the returned string.
        preserve_end: Try to keep the last comma-separated section.

    Returns:
        ``text`` unchanged when it already fits; otherwise a shortened string
        whose length never exceeds ``max_length``.
    """
    if not text or len(text) <= max_length:
        return text
    # A limit this small cannot hold content plus the 3-char "..." marker.
    # Without this guard text[:max_length - 3] becomes a NEGATIVE slice that
    # keeps almost the whole prompt, overshooting the limit.
    if max_length <= 3:
        return text[:max(max_length, 0)]
    # Strategy: keep first part (subject) and last part (style keywords)
    if preserve_end:
        parts = text.rsplit(", ", 1)
        if len(parts) == 2 and len(parts[1]) < max_length // 3:
            suffix = ", " + parts[1]
            available = max_length - len(suffix) - 5  # 5 for " ... "
            if available > 100:
                return text[:available] + " ... " + suffix
    # Simple truncation with a clean cut at a word boundary (only if that
    # boundary keeps at least half of the allowed length).
    truncated = text[:max_length - 3]
    last_space = truncated.rfind(" ")
    if last_space > max_length // 2:
        truncated = truncated[:last_space]
    return truncated + "..."
def split_prompt_parts(full_prompt):
    """
    Split a long prompt into core subject and style modifiers.

    For each style marker, its LAST occurrence is located (case-insensitively).
    If that occurrence lies in the second half of the prompt and the ", "
    preceding it lies beyond the first third, that comma is a candidate split
    point. The earliest candidate wins.

    Returns:
        ``(core, style)`` where ``style`` is the reusable suffix, or
        ``(full_prompt, "")`` when no style section is detected.
    """
    import re

    # Common style keywords that appear at the end
    style_markers = [
        "photorealistic", "cinematic", "4K", "8K", "detailed",
        "shot on", "lens", "lighting", "consistent", "camera",
        "high quality", "professional", "dramatic",
    ]
    length = len(full_prompt)
    best_split = length
    for marker in style_markers:
        # Search the ORIGINAL string case-insensitively. The previous code
        # lowercased the prompt but not the markers, so "4K"/"8K" never
        # matched; also, str.lower() can change the string length for some
        # Unicode characters, which would misalign indices.
        matches = list(re.finditer(re.escape(marker), full_prompt, re.IGNORECASE))
        if not matches:
            continue
        idx = matches[-1].start()
        if idx > length // 2:
            # Find the comma before this marker
            comma_idx = full_prompt.rfind(", ", 0, idx)
            if comma_idx > length // 3:
                best_split = min(best_split, comma_idx)
    if best_split < length:
        core = full_prompt[:best_split].strip().rstrip(",")
        style = full_prompt[best_split:].strip().lstrip(",").strip()
        return core, style
    return full_prompt, ""
# ============================================
# BASE PROVIDER CLASS
# ============================================
class BaseProvider(ABC):
    """
    Abstract base for every image-generation provider.

    Subclasses override the class-level metadata (name, models, prompt limit,
    capability flags) and implement :meth:`generate_image`. Shared helpers
    convert between PIL images, base64 strings, URLs and raw bytes.
    """

    name = "base"
    display_name = "Base Provider"
    website = ""
    supports_img2img = False
    supports_negative_prompt = True
    default_model = ""
    available_models = []
    requires_package = ""
    max_prompt_length = 10000  # Default generous limit

    def __init__(self, api_key=""):
        # Treat None (e.g. an unset config value) as "no key" instead of
        # crashing on .strip().
        self.api_key = (api_key or "").strip()

    def _truncate(self, prompt, max_len=None):
        """Truncate prompt to fit provider's limit (``max_len`` overrides it when truthy)."""
        limit = max_len or self.max_prompt_length
        return smart_truncate(prompt, limit)

    @abstractmethod
    def generate_image(
        self, prompt, negative_prompt="",
        width=1024, height=1024,
        seed=None, model=None, **kwargs,
    ):
        """Generate one image from ``prompt``.

        Implementations in this module return an RGB PIL image via the
        ``_*_to_image`` helpers below.
        """

    def img2img(
        self, prompt, image, strength=0.4,
        negative_prompt="", seed=None,
        model=None, **kwargs,
    ):
        """Fallback for providers without img2img.

        ``image`` is used only for its size; ``strength`` is ignored and a
        fresh text-to-image generation is performed.
        """
        return self.generate_image(
            prompt=prompt, negative_prompt=negative_prompt,
            width=image.width, height=image.height,
            seed=seed, model=model, **kwargs,
        )

    @staticmethod
    def _image_to_base64(img, fmt="PNG"):
        """Encode a PIL image as a base64 text string in format ``fmt``."""
        buf = io.BytesIO()
        img.save(buf, format=fmt)
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    @staticmethod
    def _decode_base64_payload(b64):
        """Decode base64 text/bytes, accepting an optional ``data:<mime>;base64,`` prefix.

        Without stripping the prefix, characters such as ``data``/``image``
        are valid base64 letters and silently corrupt the decoded bytes.
        """
        if isinstance(b64, str) and b64.startswith("data:"):
            b64 = b64.partition(",")[2]
        return base64.b64decode(b64)

    @staticmethod
    def _base64_to_image(b64):
        """Decode a base64 string (or data URI) into an RGB PIL image."""
        data = BaseProvider._decode_base64_payload(b64)
        return Image.open(io.BytesIO(data)).convert("RGB")

    @staticmethod
    def _url_to_image(url):
        """Download ``url`` (120 s timeout) and return it as an RGB PIL image.

        Inline ``data:`` URIs are decoded locally; requests cannot fetch them.
        """
        if isinstance(url, str) and url.startswith("data:"):
            return BaseProvider._base64_to_image(url)
        resp = requests.get(url, timeout=120)
        resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")

    @staticmethod
    def _bytes_to_image(data):
        """Decode raw encoded image bytes into an RGB PIL image."""
        return Image.open(io.BytesIO(data)).convert("RGB")
# ============================================
# 1. OPENAI (DALL-E 3 / gpt-image-1)
# ============================================
class OpenAIProvider(BaseProvider):
    """
    OpenAI Images API provider (DALL-E 2, DALL-E 3, gpt-image-1).

    Prompt length and output size are chosen per model, because each model
    accepts a different prompt limit and a different set of sizes.
    """

    name = "openai"
    display_name = "OpenAI (DALL-E 3 / gpt-image-1)"
    website = "https://platform.openai.com/api-keys"
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "dall-e-3"
    available_models = ["dall-e-3", "dall-e-2", "gpt-image-1"]
    requires_package = "openai"
    max_prompt_length = 3900  # DALL-E 3 limit is 4000

    # Prompt limits kept a little under each model's maximum. Models not
    # listed use max_prompt_length (the DALL-E 3 limit).
    _PROMPT_LIMITS = {
        "dall-e-2": 900,  # DALL-E 2 limit is 1000
        "gpt-image-1": 32000,  # gpt-image-1 has much higher limit
    }

    # (width, height) pairs each model accepts.
    _MODEL_SIZES = {
        "dall-e-2": ((256, 256), (512, 512), (1024, 1024)),
        "dall-e-3": ((1024, 1024), (1792, 1024), (1024, 1792)),
        "gpt-image-1": ((1024, 1024), (1536, 1024), (1024, 1536)),
    }
    # Sizes recognised for model names not listed above (previous behaviour).
    _FALLBACK_SIZES = ((1024, 1024), (1792, 1024), (1024, 1792), (512, 512), (256, 256))

    @classmethod
    def _pick_size(cls, model, width, height):
        """
        Return a ``"WxH"`` size string valid for ``model``.

        An exact allowed match is used as-is. For a known model, a non-matching
        request gets the model's allowed size with the same orientation
        (landscape/portrait). Everything else falls back to ``"1024x1024"``.
        """
        allowed = cls._MODEL_SIZES.get(model, cls._FALLBACK_SIZES)
        for size in allowed:
            if size == (width, height):
                return "%dx%d" % size
        if model in cls._MODEL_SIZES:
            if width > height:
                same_orientation = [s for s in allowed if s[0] > s[1]]
            elif height > width:
                same_orientation = [s for s in allowed if s[1] > s[0]]
            else:
                same_orientation = []
            if same_orientation:
                return "%dx%d" % same_orientation[0]
        return "1024x1024"

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Generate one image.

        ``negative_prompt`` and ``seed`` are not sent to the API. For DALL-E 3
        the ``quality`` and ``style`` kwargs are forwarded (defaults ``"hd"`` and
        ``"natural"``).
        """
        openai_mod = _require_import("openai")
        client = openai_mod.OpenAI(api_key=self.api_key)
        model = model or self.default_model
        safe_prompt = self._truncate(
            prompt, self._PROMPT_LIMITS.get(model, self.max_prompt_length)
        )
        size = self._pick_size(model, width, height)

        if model == "gpt-image-1":
            # No response_format is sent for gpt-image-1; accept either a
            # base64 payload or a URL in the response.
            response = client.images.generate(
                model="gpt-image-1", prompt=safe_prompt, n=1, size=size,
            )
            item = response.data[0]
            if getattr(item, "b64_json", None):
                return self._base64_to_image(item.b64_json)
            return self._url_to_image(item.url)

        api_kwargs = dict(
            model=model, prompt=safe_prompt, n=1, size=size,
            response_format="b64_json",
        )
        if model == "dall-e-3":
            api_kwargs["quality"] = kwargs.get("quality", "hd")
            api_kwargs["style"] = kwargs.get("style", "natural")
        response = client.images.generate(**api_kwargs)
        return self._base64_to_image(response.data[0].b64_json)
# ============================================
# 2. STABILITY AI
# ============================================
class StabilityProvider(BaseProvider):
    """
    Stability AI REST (v2beta) provider: SD3.x models via the ``sd3`` endpoint
    and Stable Image Core/Ultra via their own endpoints.
    """

    name = "stability"
    display_name = "Stability AI (SD3 / SDXL)"
    website = "https://platform.stability.ai/account/keys"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "sd3.5-large"
    available_models = [
        "sd3.5-large", "sd3.5-large-turbo", "sd3.5-medium",
        "sd3-large", "sd3-large-turbo", "sd3-medium",
        "stable-image-core", "stable-image-ultra",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://api.stability.ai/v2beta"

    # Aspect ratios accepted by the v2beta generate endpoints, as width / height.
    _ASPECT_RATIOS = {
        "1:1": 1.0,
        "16:9": 16 / 9, "9:16": 9 / 16,
        "21:9": 21 / 9, "9:21": 9 / 21,
        "3:2": 3 / 2, "2:3": 2 / 3,
        "5:4": 5 / 4, "4:5": 4 / 5,
    }

    def _headers(self):
        """Auth header plus ``Accept: image/*`` so successful responses are raw image bytes."""
        return {"Authorization": "Bearer " + self.api_key, "Accept": "image/*"}

    @classmethod
    def _aspect_ratio(cls, width, height):
        """Map a pixel size to the nearest supported aspect ratio (log-ratio distance)."""
        import math
        try:
            target = math.log(width / height)
        except (TypeError, ValueError, ZeroDivisionError):
            return "1:1"
        return min(
            cls._ASPECT_RATIOS,
            key=lambda ratio: abs(math.log(cls._ASPECT_RATIOS[ratio]) - target),
        )

    def _post(self, url, files, data):
        """POST multipart form data and decode the image response.

        Raises:
            requests.HTTPError: on a 4xx/5xx status. The message includes the
                start of the response body, which ``raise_for_status()`` alone
                would hide.
        """
        resp = requests.post(url, headers=self._headers(), files=files, data=data, timeout=120)
        if not resp.ok:
            raise requests.HTTPError(
                "Stability API error " + str(resp.status_code) + ": " + resp.text[:500],
                response=resp,
            )
        return self._bytes_to_image(resp.content)

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Text-to-image. ``width``/``height`` are mapped to the nearest supported aspect ratio."""
        model = model or self.default_model
        data = {
            "prompt": self._truncate(prompt),
            "output_format": "png",
            # v2beta endpoints size the output with aspect_ratio, not width/height.
            "aspect_ratio": self._aspect_ratio(width, height),
        }
        if negative_prompt:
            data["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            data["seed"] = seed
        if "stable-image" in model:
            # "stable-image-core" -> /core, "stable-image-ultra" -> /ultra
            url = self.API_BASE + "/stable-image/generate/" + model.replace("stable-image-", "")
        else:
            url = self.API_BASE + "/stable-image/generate/sd3"
            data["model"] = model
        # The endpoint expects multipart/form-data even without an upload,
        # hence the dummy files entry.
        return self._post(url, {"none": ""}, data)

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image through the ``sd3`` endpoint using the requested SD3 model.

        Stable Image Core/Ultra names are not SD3 models, so they fall back to
        ``default_model`` here.
        """
        model = model or self.default_model
        if "stable-image" in model:
            model = self.default_model
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        buf.seek(0)
        data = {
            "prompt": self._truncate(prompt), "strength": strength,
            "output_format": "png", "mode": "image-to-image",
            # Previously omitted, so the caller's model choice was ignored.
            "model": model,
        }
        if negative_prompt:
            data["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            data["seed"] = seed
        files = {"image": ("input.png", buf, "image/png")}
        return self._post(self.API_BASE + "/stable-image/generate/sd3", files, data)
# ============================================
# 3. REPLICATE
# ============================================
class ReplicateProvider(BaseProvider):
    """Replicate provider: runs any hosted model id (FLUX, SDXL, SD3.5, ...)."""

    name = "replicate"
    display_name = "Replicate (Flux / SDXL / Any)"
    website = "https://replicate.com/account/api-tokens"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "black-forest-labs/flux-1.1-pro"
    available_models = [
        "black-forest-labs/flux-1.1-pro",
        "black-forest-labs/flux-schnell",
        "black-forest-labs/flux-dev",
        "stability-ai/sdxl:latest",
        "stability-ai/stable-diffusion-3.5-large",
        "bytedance/sdxl-lightning-4step:latest",
    ]
    requires_package = "replicate"
    max_prompt_length = 10000

    # Model used by img2img when the caller does not name one.
    _IMG2IMG_DEFAULT_MODEL = "stability-ai/sdxl:latest"

    def _client(self):
        """Lazily import the SDK and build a client bound to this provider's token."""
        replicate_mod = _require_import("replicate")
        return replicate_mod.Client(api_token=self.api_key)

    @staticmethod
    def _add_optional_inputs(input_params, model_id, negative_prompt, seed):
        """Add negative prompt and seed to ``input_params`` in place.

        ``negative_prompt`` is never sent to FLUX models. The text-to-image path
        already skipped it for them; img2img now follows the same rule.
        """
        if negative_prompt and "flux" not in model_id.lower():
            input_params["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            input_params["seed"] = seed

    @staticmethod
    def _output_to_url(output):
        """
        Extract the first image URL from a ``Client.run`` result.

        Handles a list/tuple of outputs (first element used), objects exposing
        a ``url`` attribute, and plain URL strings.

        Raises:
            ValueError: if the model returned an empty list.
        """
        if isinstance(output, (list, tuple)):
            if not output:
                raise ValueError("No image returned from Replicate")
            output = output[0]
        url = getattr(output, "url", None)
        if url:
            return str(url)
        return str(output)

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Run a text-to-image model (``default_model`` unless ``model`` is given)."""
        client = self._client()
        model_id = model or self.default_model
        input_params = {"prompt": self._truncate(prompt), "width": width, "height": height}
        self._add_optional_inputs(input_params, model_id, negative_prompt, seed)
        output = client.run(model_id, input=input_params)
        return self._url_to_image(self._output_to_url(output))

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Run an img2img model; ``strength`` is sent as ``prompt_strength``."""
        client = self._client()
        model_id = model or self._IMG2IMG_DEFAULT_MODEL
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        buf.seek(0)
        input_params = {
            "prompt": self._truncate(prompt), "image": buf, "prompt_strength": strength,
        }
        self._add_optional_inputs(input_params, model_id, negative_prompt, seed)
        output = client.run(model_id, input=input_params)
        return self._url_to_image(self._output_to_url(output))
# ============================================
# 4. TOGETHER AI
# ============================================
class TogetherProvider(BaseProvider):
    """Together AI images provider (FLUX / SDXL) via the ``together`` SDK."""

    name = "together"
    display_name = "Together AI (Flux / SDXL)"
    website = "https://api.together.xyz/settings/api-keys"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "black-forest-labs/FLUX.1.1-pro"
    available_models = [
        "black-forest-labs/FLUX.1.1-pro",
        "black-forest-labs/FLUX.1-schnell-Free",
        "black-forest-labs/FLUX.1-dev",
        "stabilityai/stable-diffusion-xl-base-1.0",
    ]
    requires_package = "together"
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Generate one image and decode the base64 payload.

        The number of diffusion steps comes from ``kwargs["steps"]`` (default 28).
        """
        together_mod = _require_import("together")
        client = together_mod.Together(api_key=self.api_key)
        request = {
            "model": model or self.default_model,
            "prompt": self._truncate(prompt),
            "width": width,
            "height": height,
            "steps": kwargs.get("steps", 28),
            "n": 1,
            "response_format": "b64_json",
        }
        if negative_prompt:
            request["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            request["seed"] = seed
        first = client.images.generate(**request).data[0]
        return self._base64_to_image(first.b64_json)
# ============================================
# 5. FAL.AI
# ============================================
class FalProvider(BaseProvider):
    """Fal.ai provider using ``fal_client.subscribe`` for both text- and image-to-image."""

    name = "fal"
    display_name = "Fal.ai (Flux Pro / Fast SDXL)"
    website = "https://fal.ai/dashboard/keys"
    supports_img2img = True
    supports_negative_prompt = True
    default_model = "fal-ai/flux-pro/v1.1"
    available_models = [
        "fal-ai/flux-pro/v1.1", "fal-ai/flux/dev", "fal-ai/flux/schnell",
        "fal-ai/flux-realism", "fal-ai/fast-sdxl",
        "fal-ai/stable-diffusion-v35-large", "fal-ai/recraft-v3",
    ]
    requires_package = "fal_client"
    max_prompt_length = 10000

    def _connect(self):
        """Import ``fal_client`` and publish this provider's key as ``FAL_KEY``.

        NOTE: this writes to ``os.environ``, i.e. it is a process-wide setting.
        """
        fal_client = _require_import("fal_client", "fal-client")
        os.environ["FAL_KEY"] = self.api_key
        return fal_client

    def _first_image(self, fal_client, model_id, arguments, missing_message):
        """Run ``model_id`` and download its first image, or raise ``ValueError(missing_message)``."""
        result = fal_client.subscribe(model_id, arguments=arguments)
        images = result.get("images", [])
        if not images:
            raise ValueError(missing_message)
        return self._url_to_image(images[0]["url"])

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """Text-to-image with an explicit ``image_size`` of ``width`` x ``height``."""
        fal_client = self._connect()
        arguments = {
            "prompt": self._truncate(prompt),
            "image_size": {"width": width, "height": height},
            "num_images": 1,
        }
        if negative_prompt:
            arguments["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            arguments["seed"] = seed
        return self._first_image(
            fal_client, model or self.default_model, arguments,
            "No image returned from Fal.ai",
        )

    def img2img(self, prompt, image, strength=0.4, negative_prompt="",
                seed=None, model=None, **kwargs):
        """Image-to-image: the source image is sent inline as a PNG data URI."""
        fal_client = self._connect()
        arguments = {
            "prompt": self._truncate(prompt),
            "image_url": "data:image/png;base64," + self._image_to_base64(image),
            "strength": strength,
            "num_images": 1,
        }
        if negative_prompt:
            arguments["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            arguments["seed"] = seed
        return self._first_image(
            fal_client, model or "fal-ai/flux/dev/image-to-image", arguments,
            "No image from Fal.ai img2img",
        )
# ============================================
# 6. GOOGLE GEMINI
# ============================================
class GoogleGeminiProvider(BaseProvider):
    """Google Imagen 3 via the ``google-generativeai`` SDK."""

    name = "google"
    display_name = "Google Gemini (Imagen 3)"
    website = "https://aistudio.google.com/apikey"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "imagen-3.0-generate-002"
    available_models = ["imagen-3.0-generate-002", "imagen-3.0-fast-generate-001"]
    requires_package = "google.generativeai"
    max_prompt_length = 5000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Generate one image with Imagen.

        ``width``, ``height`` and ``seed`` are not passed to the SDK call.
        """
        genai = _require_import("google.generativeai", "google-generativeai")
        genai.configure(api_key=self.api_key)
        imagen = genai.ImageGenerationModel(model or self.default_model)
        request = {"prompt": self._truncate(prompt), "number_of_images": 1}
        if negative_prompt:
            request["negative_prompt"] = smart_truncate(negative_prompt, 2000)
        response = imagen.generate_images(**request)
        if not response.images:
            raise ValueError("No image returned from Imagen")
        # NOTE(review): _pil_image is a private SDK attribute; may break on SDK upgrades.
        return response.images[0]._pil_image.convert("RGB")
# ============================================
# 7. HUGGING FACE INFERENCE API
# ============================================
class HuggingFaceProvider(BaseProvider):
    """Hugging Face serverless Inference API for text-to-image models."""

    name = "huggingface"
    display_name = "HuggingFace Inference API"
    website = "https://huggingface.co/settings/tokens"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "black-forest-labs/FLUX.1-schnell"
    available_models = [
        "black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev",
        "stabilityai/stable-diffusion-xl-base-1.0",
        "stabilityai/stable-diffusion-3.5-large",
        "runwayml/stable-diffusion-v1-5",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://api-inference.huggingface.co/models"

    # Extra attempts after an HTTP 503 (model still loading).
    MAX_LOAD_RETRIES = 2
    # Wait used when the 503 body carries no usable estimate, and the upper bound (seconds).
    DEFAULT_LOAD_WAIT = 20.0
    MAX_LOAD_WAIT = 60.0

    def _load_wait_seconds(self, resp):
        """Seconds to wait after a 503: the body's ``estimated_time`` if present, clamped to [1, MAX_LOAD_WAIT]."""
        try:
            estimate = float(resp.json().get("estimated_time", self.DEFAULT_LOAD_WAIT))
        except (ValueError, TypeError, AttributeError):
            # Body not JSON, not an object, or estimate not numeric.
            return self.DEFAULT_LOAD_WAIT
        return max(1.0, min(estimate, self.MAX_LOAD_WAIT))

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        POST the prompt to the model endpoint and decode the returned image bytes.

        While the endpoint answers 503 (model loading), waits and retries up to
        ``MAX_LOAD_RETRIES`` times.

        Raises:
            requests.HTTPError: on a final 4xx/5xx status.
            ValueError: if a successful response is JSON/text instead of an image.
        """
        model_id = model or self.default_model
        url = self.API_BASE + "/" + model_id
        headers = {"Authorization": "Bearer " + self.api_key}
        payload = {"inputs": self._truncate(prompt), "parameters": {"width": width, "height": height}}
        if negative_prompt:
            payload["parameters"]["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["parameters"]["seed"] = seed

        for attempt in range(self.MAX_LOAD_RETRIES + 1):
            resp = requests.post(url, headers=headers, json=payload, timeout=180)
            if resp.status_code != 503 or attempt == self.MAX_LOAD_RETRIES:
                break
            time.sleep(self._load_wait_seconds(resp))
        resp.raise_for_status()

        # A JSON/text body is an error message, not an image; surface it
        # instead of letting PIL fail with "cannot identify image file".
        content_type = resp.headers.get("Content-Type", "")
        if "json" in content_type or content_type.startswith("text/"):
            raise ValueError("HuggingFace returned a non-image response: " + resp.text[:300])
        return self._bytes_to_image(resp.content)
# ============================================
# 8. xAI GROK
# ============================================
class XAIProvider(BaseProvider):
    """xAI image generation through the OpenAI SDK pointed at ``https://api.x.ai/v1``."""

    name = "xai"
    display_name = "xAI Grok (Aurora)"
    website = "https://console.x.ai/team/default/api-keys"
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "grok-2-image"
    available_models = ["grok-2-image"]
    requires_package = "openai"
    max_prompt_length = 4000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Generate one image.

        ``negative_prompt``, ``width``, ``height`` and ``seed`` are ignored; the
        request always asks for ``1024x1024``.
        """
        openai_mod = _require_import("openai")
        client = openai_mod.OpenAI(api_key=self.api_key, base_url="https://api.x.ai/v1")
        result = client.images.generate(
            model=model or self.default_model,
            prompt=self._truncate(prompt),
            n=1,
            response_format="b64_json",
            size="1024x1024",
        )
        return self._base64_to_image(result.data[0].b64_json)
# ============================================
# 9. FIREWORKS AI
# ============================================
class FireworksProvider(BaseProvider):
    """Fireworks AI images endpoint called directly over REST."""

    name = "fireworks"
    display_name = "Fireworks AI (Flux / SD)"
    website = "https://fireworks.ai/account/api-keys"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "accounts/fireworks/models/flux-1-1-pro"
    available_models = [
        "accounts/fireworks/models/flux-1-1-pro",
        "accounts/fireworks/models/flux-1-schnell-fp8",
        "accounts/fireworks/models/flux-1-dev-fp8",
        "accounts/fireworks/models/stable-diffusion-xl-1024-v1-0",
    ]
    requires_package = ""
    max_prompt_length = 10000

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """POST an OpenAI-style generation request and decode ``data[0].b64_json``."""
        request_body = {
            "model": model or self.default_model,
            "prompt": self._truncate(prompt),
            "n": 1,
            "size": "{}x{}".format(width, height),
            "response_format": "b64_json",
        }
        if negative_prompt:
            request_body["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            request_body["seed"] = seed
        resp = requests.post(
            "https://api.fireworks.ai/inference/v1/images/generations",
            headers={"Authorization": "Bearer " + self.api_key, "Content-Type": "application/json"},
            json=request_body,
            timeout=120,
        )
        resp.raise_for_status()
        encoded = resp.json()["data"][0]["b64_json"]
        return self._base64_to_image(encoded)
# ============================================
# 10. IDEOGRAM
# ============================================
class IdeogramProvider(BaseProvider):
    """Ideogram ``/generate`` REST endpoint."""

    name = "ideogram"
    display_name = "Ideogram (v2 / v2-turbo)"
    website = "https://ideogram.ai/manage-api"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "V_2"
    available_models = ["V_2", "V_2_TURBO", "V_1", "V_1_TURBO"]
    requires_package = ""
    max_prompt_length = 10000

    # Aspect-ratio enum values for /generate, mapped to width / height.
    # ASPECT_1_1 is listed first so it wins any tie.
    _ASPECT_RATIOS = {
        "ASPECT_1_1": 1.0,
        "ASPECT_16_9": 16 / 9, "ASPECT_9_16": 9 / 16,
        "ASPECT_16_10": 16 / 10, "ASPECT_10_16": 10 / 16,
        "ASPECT_3_2": 3 / 2, "ASPECT_2_3": 2 / 3,
        "ASPECT_4_3": 4 / 3, "ASPECT_3_4": 3 / 4,
        "ASPECT_3_1": 3.0, "ASPECT_1_3": 1 / 3,
    }

    @classmethod
    def _aspect_ratio(cls, width, height):
        """Return the enum value whose ratio is nearest (in log space) to width / height."""
        import math
        try:
            target = math.log(width / height)
        except (TypeError, ValueError, ZeroDivisionError):
            return "ASPECT_1_1"
        return min(
            cls._ASPECT_RATIOS,
            key=lambda name: abs(math.log(cls._ASPECT_RATIOS[name]) - target),
        )

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Generate one image and download it from the returned URL.

        The API takes an aspect-ratio enum instead of pixel sizes, so
        ``width``/``height`` select the nearest ratio. Previously the ratio was
        hard-coded to square. The default 1024x1024 still maps to ASPECT_1_1.
        """
        url = "https://api.ideogram.ai/generate"
        headers = {"Api-Key": self.api_key, "Content-Type": "application/json"}
        image_request = {
            "prompt": self._truncate(prompt),
            "model": model or self.default_model,
            "magic_prompt_option": "AUTO",
            "aspect_ratio": self._aspect_ratio(width, height),
        }
        if negative_prompt:
            image_request["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            image_request["seed"] = seed
        resp = requests.post(url, headers=headers, json={"image_request": image_request}, timeout=120)
        resp.raise_for_status()
        data = resp.json()
        return self._url_to_image(data["data"][0]["url"])
# ============================================
# 11. LEONARDO AI
# ============================================
class LeonardoProvider(BaseProvider):
    """Leonardo AI: submits an async generation job, then polls until it completes."""

    name = "leonardo"
    display_name = "Leonardo AI"
    website = "https://app.leonardo.ai/api-access"
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "6b645e3a-d64f-4341-a6d8-7a3690fbf042"
    available_models = [
        "6b645e3a-d64f-4341-a6d8-7a3690fbf042",
        "aa77f04e-3eec-4034-9c07-d0f619684628",
        "1e60896f-3c26-4296-8ecc-53e2afecc132",
    ]
    requires_package = ""
    max_prompt_length = 10000
    API_BASE = "https://cloud.leonardo.ai/api/rest/v1"

    # Polling schedule: up to POLL_ATTEMPTS checks, POLL_INTERVAL seconds apart (150 s total).
    POLL_ATTEMPTS = 30
    POLL_INTERVAL = 5

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Create a generation job and poll it for the first image.

        Raises:
            RuntimeError: if the job reports status ``FAILED``.
            TimeoutError: if no image is available after the polling schedule.
        """
        headers = {"Authorization": "Bearer " + self.api_key, "Content-Type": "application/json"}
        payload = {
            "prompt": self._truncate(prompt), "modelId": model or self.default_model,
            "width": width, "height": height, "num_images": 1,
        }
        if negative_prompt:
            payload["negative_prompt"] = smart_truncate(negative_prompt, 5000)
        if seed is not None:
            payload["seed"] = seed
        resp = requests.post(self.API_BASE + "/generations", headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
        gen_id = resp.json()["sdGenerationJob"]["generationId"]

        for _ in range(self.POLL_ATTEMPTS):
            time.sleep(self.POLL_INTERVAL)
            poll = requests.get(self.API_BASE + "/generations/" + gen_id, headers=headers, timeout=30)
            poll.raise_for_status()
            # The field may come back as null; `.get(key, {})` would then yield
            # None and crash on the next .get().
            gen = poll.json().get("generations_by_pk") or {}
            status = gen.get("status")
            if status == "FAILED":
                # Fail fast instead of polling until a misleading timeout.
                raise RuntimeError("Leonardo generation failed: " + str(gen_id))
            if status == "COMPLETE":
                images = gen.get("generated_images") or []
                if images:
                    return self._url_to_image(images[0]["url"])
        raise TimeoutError("Leonardo generation timed out")
# ============================================
# 12. CUSTOM OPENAI-COMPATIBLE
# ============================================
class CustomOpenAIProvider(BaseProvider):
    """Any server implementing the OpenAI Images API, reached via an optional ``base_url``."""

    name = "custom_openai"
    display_name = "Custom OpenAI-Compatible API"
    website = ""
    supports_img2img = False
    supports_negative_prompt = False
    default_model = "dall-e-3"
    available_models = ["dall-e-3", "dall-e-2", "custom"]
    requires_package = "openai"
    max_prompt_length = 3900

    def __init__(self, api_key="", base_url=""):
        super().__init__(api_key)
        # Whitespace and trailing slashes removed; a falsy value means "SDK default".
        self.base_url = "" if not base_url else base_url.strip().rstrip("/")

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        Generate one image as base64.

        ``negative_prompt`` and ``seed`` are not sent; the size is passed
        through verbatim as ``"<width>x<height>"``.
        """
        openai_mod = _require_import("openai")
        client_options = {"api_key": self.api_key}
        if self.base_url:
            client_options["base_url"] = self.base_url
        client = openai_mod.OpenAI(**client_options)
        response = client.images.generate(
            model=model or self.default_model,
            prompt=self._truncate(prompt),
            n=1,
            size="{}x{}".format(width, height),
            response_format="b64_json",
        )
        return self._base64_to_image(response.data[0].b64_json)
# ============================================
# 13. DIRECT URL API
# ============================================
class DirectURLProvider(BaseProvider):
    """
    Generic REST provider: POSTs a JSON body to ``endpoint_url`` and extracts an image.

    Accepted responses:
    - an ``image/*`` body;
    - JSON with an ``images``/``data``/``output``/``result`` entry;
    - a bare JSON array.

    A list entry uses its first element. The element may be a string (http(s)
    URL, base64, or base64 data URI) or an object with a non-empty
    ``b64_json``/``url``/``image`` string field.
    """

    name = "direct_url"
    display_name = "Direct URL API (Any REST Endpoint)"
    website = ""
    supports_img2img = False
    supports_negative_prompt = True
    default_model = "custom"
    available_models = ["custom"]
    requires_package = ""
    max_prompt_length = 50000

    # Top-level JSON keys searched, in order, for the image payload.
    _RESPONSE_KEYS = ("images", "data", "output", "result")
    # Fields searched, in order, inside an object payload.
    _ITEM_KEYS = ("b64_json", "url", "image")

    def __init__(self, api_key="", endpoint_url=""):
        super().__init__(api_key)
        self.endpoint_url = (endpoint_url or "").strip()

    def generate_image(self, prompt, negative_prompt="", width=1024, height=1024,
                       seed=None, model=None, **kwargs):
        """
        POST ``{prompt, width, height[, negative_prompt, seed, model]}`` and decode the reply.

        Raises:
            ValueError: if no endpoint is configured or no image can be found.
            requests.HTTPError: on a 4xx/5xx status.
        """
        if not self.endpoint_url:
            raise ValueError("No endpoint URL provided")
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            # Only authenticate when a key is set; an empty "Bearer " header may
            # be rejected by endpoints that need no auth.
            headers["Authorization"] = "Bearer " + self.api_key
        payload = {"prompt": self._truncate(prompt), "width": width, "height": height}
        if negative_prompt:
            payload["negative_prompt"] = smart_truncate(negative_prompt, 10000)
        if seed is not None:
            payload["seed"] = seed
        if model and model != "custom":
            payload["model"] = model
        resp = requests.post(self.endpoint_url, headers=headers, json=payload, timeout=180)
        resp.raise_for_status()
        if "image" in resp.headers.get("Content-Type", ""):
            return self._bytes_to_image(resp.content)

        data = resp.json()
        if isinstance(data, dict):
            candidates = [data[key] for key in self._RESPONSE_KEYS if key in data]
        elif isinstance(data, list):
            # Bare JSON array: treat it like a list-valued payload entry.
            candidates = [data]
        else:
            candidates = []
        for candidate in candidates:
            image = self._image_from_payload(candidate)
            if image is not None:
                return image
        raise ValueError("Could not parse image from API response")

    def _image_from_payload(self, item):
        """Decode an image from one payload entry, or return None if it holds none."""
        if isinstance(item, list):
            if not item:
                return None
            item = item[0]
        if isinstance(item, dict):
            for subkey in self._ITEM_KEYS:
                val = item.get(subkey)
                # Skip null/non-string fields, e.g. {"b64_json": null, "url": "..."}.
                if isinstance(val, str) and val:
                    return self._image_from_string(val)
            return None
        if isinstance(item, str) and item:
            return self._image_from_string(item)
        return None

    def _image_from_string(self, value):
        """Treat ``value`` as an http(s) URL, a data URI, or plain base64."""
        if value.startswith("http"):
            return self._url_to_image(value)
        if value.startswith("data:"):
            # Drop the "data:<mime>;base64," header before decoding.
            value = value.partition(",")[2]
        return self._base64_to_image(value)
# ============================================
# PROVIDER REGISTRY
# ============================================
# Registry: provider key -> provider class. Each key is the class's own
# ``name`` attribute, so the lookup key and the class metadata cannot drift
# apart. Tuple order fixes the dict's iteration order.
PROVIDERS = {
    provider_cls.name: provider_cls
    for provider_cls in (
        OpenAIProvider,
        StabilityProvider,
        ReplicateProvider,
        TogetherProvider,
        FalProvider,
        GoogleGeminiProvider,
        HuggingFaceProvider,
        XAIProvider,
        FireworksProvider,
        IdeogramProvider,
        LeonardoProvider,
        CustomOpenAIProvider,
        DirectURLProvider,
    )
}

# Reverse lookup for UIs: human-readable display name -> registry key.
PROVIDER_DISPLAY_NAMES = dict(
    (provider_cls.display_name, key) for key, provider_cls in PROVIDERS.items()
)
def get_provider(provider_name, api_key, **kwargs):
    """
    Instantiate the provider registered under ``provider_name``.

    ``custom_openai`` also receives ``base_url`` and ``direct_url`` also receives
    ``endpoint_url`` from ``kwargs`` (default ``""``). Other kwargs are ignored.

    Raises:
        ValueError: if ``provider_name`` is not registered.
    """
    provider_cls = PROVIDERS.get(provider_name)
    if provider_cls is None:
        raise ValueError("Unknown provider: " + str(provider_name))
    # The one extra constructor argument (if any) each provider takes from kwargs.
    extra_arg = {"custom_openai": "base_url", "direct_url": "endpoint_url"}.get(provider_name)
    if extra_arg is None:
        return provider_cls(api_key=api_key)
    return provider_cls(api_key=api_key, **{extra_arg: kwargs.get(extra_arg, "")})
def get_provider_info():
    """
    Describe every registered provider for display/configuration UIs.

    Package availability is checked with ``importlib.util.find_spec``, so SDKs
    are located rather than imported. For dotted names only the parent package
    is imported. Previously only the first component was checked (e.g. the
    shared ``google`` namespace for ``google.generativeai``), which reported
    missing SDKs as installed.

    Returns:
        A list of dicts, one per provider, in registry order.
    """
    from importlib.util import find_spec

    def _package_installed(package):
        if not package:
            return True
        try:
            return find_spec(package) is not None
        except (ImportError, ValueError):
            # Parent package missing, or a loaded module without a usable spec.
            return False

    info = []
    for key, cls in PROVIDERS.items():
        pkg = cls.requires_package
        info.append({
            "name": key, "display_name": cls.display_name,
            "website": cls.website, "supports_img2img": cls.supports_img2img,
            "supports_negative_prompt": cls.supports_negative_prompt,
            "default_model": cls.default_model, "available_models": cls.available_models,
            "requires_package": pkg, "package_installed": _package_installed(pkg),
            "max_prompt_length": cls.max_prompt_length,
        })
    return info
def get_models_for_provider(provider_name):
    """
    Return the model ids offered by ``provider_name`` ([] if unknown).

    A fresh list is returned, so callers may mutate it without altering the
    provider class's shared ``available_models`` attribute.
    """
    cls = PROVIDERS.get(provider_name)
    return list(cls.available_models) if cls else []