import os
import json
import requests
import gradio as gr
from datetime import datetime
# API Configuration
# The key is read from the environment; default to '' so a missing key does
# not silently turn into the literal header value "Bearer None".
API_KEY = os.getenv('StableCogKey', '')
API_HOST = 'https://api.stablecog.com'
if not API_KEY:
    print("Warning: environment variable 'StableCogKey' is not set; "
          "StableCog API requests will be rejected.")
headers = {
    'Authorization': f'Bearer {API_KEY}',
    'Content-Type': 'application/json',
}
# Global state for outputs pagination. These are module-level, so they are
# shared by every caller in this process.
current_outputs = []  # outputs cached by the most recent successful fetch
current_page = 0      # zero-based index of the page currently displayed
page_size = 10        # number of outputs shown per page
# ========== MODEL AND SCHEDULER MANAGEMENT ==========
def get_model_list():
    """Return ``(display_name, model_id)`` pairs for the model dropdown.

    Models are read from ``GET {API_HOST}/v1/image/generation/models``. The
    model flagged ``is_default`` is listed first; the rest are ordered by
    name. Entries without a name or id are skipped rather than aborting the
    whole list.

    If the request fails, returns a non-200 status, or yields no usable
    models, a single SDXL 1.0 fallback entry is returned so the dropdown
    always has at least one choice.
    """
    fallback = [("SDXL 1.0", "22b0857d-7edc-4d00-9cd9-45aa509db093")]
    try:
        url = f'{API_HOST}/v1/image/generation/models'
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            print(f"Error fetching models: HTTP {response.status_code}")
            return fallback
        # A null/missing "models" field is treated as an empty list.
        models = response.json().get('models') or []
        # Drop malformed records instead of letting one KeyError discard all.
        usable = [m for m in models if m.get('name') and m.get('id')]
        # Sort: default first (False < True), then by name.
        usable.sort(key=lambda m: (not m.get('is_default', False), str(m['name'])))
        choices = [(str(m['name']), m['id']) for m in usable]
        # An empty catalogue would leave the dropdown with no choices.
        return choices or fallback
    except Exception as e:
        print(f"Error fetching models: {e}")
        return fallback
def get_scheduler_list():
    """Return the hard-coded ``(display_name, scheduler_id)`` pairs.

    The order is fixed; "Euler" is the first entry. A new list is built on
    every call, so callers may mutate the result freely.
    """
    # NOTE(review): the ids from "DPMSolver++" onward share the suffix
    # "c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f" and look like placeholders rather
    # than real StableCog scheduler ids. Confirm them against the API.
    scheduler_ids = {
        "Euler": "b7224e56-1440-43b9-ac86-66d66f9e8c91",
        "Euler A": "6fb13a76-990d-49df-a2ab-7d9d22c33e3d",
        "DDIM": "c5a0bad3-bd9d-4c5c-96e3-9d8e8c0c7a6b",
        "DPMSolver++": "e9c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "DPM++ 2M Karras": "f1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "DPM++ SDE": "a1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "DPM++ SDE Karras": "b1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "Heun": "c1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "LMS": "d1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "LMS Karras": "e1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
    }
    # dicts preserve insertion order, so items() yields the pairs in order.
    return list(scheduler_ids.items())
# Build the dropdown choices once, at import time. get_model_list() performs
# an HTTP request, so importing this module touches the network.
MODELS = get_model_list()
SCHEDULERS = get_scheduler_list()
# Lookups from the display name shown in a dropdown to the API id.
MODEL_MAP = dict(MODELS)
SCHEDULER_MAP = dict(SCHEDULERS)
# Dropdown defaults: first entry of each list, else a hard-coded name.
DEFAULT_MODEL = "SDXL 1.0"
if MODELS:
    DEFAULT_MODEL = MODELS[0][0]
DEFAULT_SCHEDULER = "Euler"
if SCHEDULERS:
    DEFAULT_SCHEDULER = SCHEDULERS[0][0]
# ========== MODELS TAB ==========
def get_models():
    """Fetch the model catalogue and format it for the Models tab.

    Returns:
        tuple[str, str]: ``(display_text, raw_json)``.
            - ``display_text`` is a numbered, human-readable listing.
            - ``raw_json`` is the response body pretty-printed as JSON, so
              the ``gr.Code(language="json")`` viewer gets valid JSON rather
              than a Python ``repr``.

        On an HTTP error, the first element is the status and the second is
        the response text. On any exception, the first element is the error
        and the second is ``"No data"``.
    """
    def yes_no(flag):
        return '✅' if flag else '❌'

    try:
        url = f'{API_HOST}/v1/image/generation/models'
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            return f"❌ Error {response.status_code}", f"Error: {response.text}"

        data = response.json()
        models = data.get('models') or []
        lines = [
            f"📊 Found {len(models)} models",
            f"⏰ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "",
        ]
        for i, model in enumerate(models, 1):
            lines += [
                f"{i}. 🔹 **{model.get('name', 'Unknown')}**",
                f"   📝 {model.get('description', 'No description')}",
                f"   🏷️ Type: {model.get('type', 'unknown')}",
                f"   🌐 Public: {yes_no(model.get('is_public'))}",
                f"   ⭐ Default: {yes_no(model.get('is_default'))}",
                f"   👥 Community: {yes_no(model.get('is_community'))}",
                "─" * 40,
            ]
        display_text = "\n".join(lines) + "\n"
        return display_text, json.dumps(data, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"❌ Error: {str(e)}", "No data"
# ========== GENERATE TAB ==========
def generate_image(prompt, negative_prompt, model_name, width, height,
                   num_outputs, guidance_scale, inference_steps,
                   scheduler_name, seed, init_image_url, prompt_strength):
    """Generate images with the StableCog API.

    Args:
        prompt: text prompt. Required; blank prompts are rejected locally
            without calling the API.
        negative_prompt: optional text to avoid; blank values are omitted.
        model_name: display name looked up in ``MODEL_MAP``. Unknown names
            fall back to the first model (or the hard-coded SDXL id).
        width, height, num_outputs, inference_steps: coerced to ``int``.
        guidance_scale: coerced to ``float``.
        scheduler_name: display name looked up in ``SCHEDULER_MAP``. Unknown
            names fall back to the Euler id.
        seed: optional string. It must parse as an integer; otherwise it is
            left out of the request and a warning is added to the result
            text.
        init_image_url: optional img2img source URL.
        prompt_strength: sent only together with ``init_image_url``.

    Returns:
        tuple[str, str, str]: ``(display_text, gallery_html, raw_json)``.
        On any failure, ``gallery_html`` is ``""`` and the other two elements
        hold the same error message.
    """
    if not (prompt and prompt.strip()):
        error_msg = "❌ Please enter a prompt before generating."
        return error_msg, "", error_msg

    try:
        url = f'{API_HOST}/v1/image/generation/create'
        # Translate dropdown display names into API ids.
        model_id = MODEL_MAP.get(
            model_name,
            MODELS[0][1] if MODELS else "22b0857d-7edc-4d00-9cd9-45aa509db093",
        )
        scheduler_id = SCHEDULER_MAP.get(scheduler_name, "b7224e56-1440-43b9-ac86-66d66f9e8c91")

        data = {
            "prompt": prompt,
            "model_id": model_id,
            "width": int(width),
            "height": int(height),
            "num_outputs": int(num_outputs),
            "guidance_scale": float(guidance_scale),
            "inference_steps": int(inference_steps),
            "scheduler_id": scheduler_id,
        }
        notes = []  # user-visible warnings about ignored inputs

        if negative_prompt and negative_prompt.strip():
            data["negative_prompt"] = negative_prompt.strip()
        if seed and seed.strip():
            try:
                data["seed"] = int(seed.strip())
            except ValueError:
                notes.append(f"⚠️ Ignored non-integer seed {seed.strip()!r}; "
                             "it was left out of the request.")
        init_url = init_image_url.strip() if init_image_url else ""
        if init_url:
            data["init_image_url"] = init_url
            # prompt_strength only applies to img2img. The UI slider is merely
            # hidden (it always holds a value), so gate on the init image.
            if prompt_strength is not None:
                data["prompt_strength"] = float(prompt_strength)

        print(f"Sending data: {json.dumps(data, indent=2)}")
        response = requests.post(
            url,
            json=data,
            headers=headers,
            timeout=30,  # generation takes longer than catalogue reads
        )
        print(f"Response status: {response.status_code}")
        print(f"Response text: {response.text[:200]}...")

        if response.status_code != 200:
            error_msg = f"❌ Generation failed: {response.status_code}"
            try:
                error_detail = response.json()
            except ValueError:  # body is not JSON
                error_msg += f"\nResponse: {response.text}"
            else:
                error_msg += f"\nDetails: {json.dumps(error_detail, indent=2)}"
            return error_msg, "", error_msg

        result = response.json()
        outputs = result.get('outputs') or []
        remaining_credits = result.get('remaining_credits', 0)
        settings = result.get('settings') or {}

        lines = [
            "✅ Generation successful!",
            f"🪙 Remaining credits: {remaining_credits}",
            f"🖼️ Generated {len(outputs)} image(s)",
            f"⏰ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "",
        ]
        if notes:
            lines += notes + [""]

        image_urls = []
        for i, output in enumerate(outputs, 1):
            image_url = output.get('url') or ''
            shown_url = image_url[:60] + ('...' if len(image_url) > 60 else '')
            lines.append(f"{i}. Output ID: {output.get('id', 'N/A')}")
            lines.append(f"   📷 URL: {shown_url}")
            if image_url:
                image_urls.append(image_url)

        gallery_html = create_gallery_html(image_urls, "Generated Images") if image_urls else ""

        lines += ["", "⚙️ Settings used:"]
        lines += [f"   {key}: {value}" for key, value in settings.items()]
        display_text = "\n".join(lines) + "\n"
        return display_text, gallery_html, json.dumps(result, indent=2, ensure_ascii=False)
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, "", error_msg
# ========== OUTPUTS TAB ==========
def create_gallery_html(image_urls, title="Gallery"):
    """Build an HTML thumbnail grid for a list of image URLs.

    Each thumbnail links to the full-size image in a new browser tab. The
    title and every URL are HTML-escaped, so values taken from API responses
    cannot break out of the markup. Only ``http(s)`` URLs are rendered, and
    empty entries are skipped.

    Args:
        image_urls: iterable of image URL strings.
        title: heading shown above the grid.

    Returns:
        str: an HTML fragment for ``gr.HTML``. When nothing is renderable,
        it contains a "No images to display." notice instead of images.
    """
    from html import escape

    parts = [
        '<div class="sc-gallery" style="font-family: sans-serif;">',
        f'<h3 style="margin: 0 0 8px;">{escape(str(title))}</h3>',
        '<div style="display: grid; '
        'grid-template-columns: repeat(auto-fill, minmax(180px, 1fr)); gap: 8px;">',
    ]
    rendered = 0
    for i, image_url in enumerate(image_urls, 1):
        url = str(image_url or '').strip()
        # Reject empty values and non-web schemes such as "javascript:".
        if not url.lower().startswith(("http://", "https://")):
            continue
        src = escape(url, quote=True)
        parts.append(
            f'<a href="{src}" target="_blank" rel="noopener noreferrer" title="Image {i}">'
            f'<img src="{src}" alt="Image {i}" loading="lazy" '
            'style="width: 100%; height: auto; border-radius: 6px; display: block;"></a>'
        )
        rendered += 1
    parts.append('</div>')
    if not rendered:
        parts.append('<p style="opacity: 0.7;">No images to display.</p>')
    parts.append('</div>')
    return "\n".join(parts)
def fetch_outputs():
    """Fetch the user's generated outputs and cache them for pagination.

    On success, the response's ``outputs`` list (``[]`` if null/missing)
    replaces the module-level ``current_outputs``, and the parsed response
    dict is returned.

    On a non-200 status or any error, the problem is printed, ``None`` is
    returned, and ``current_outputs`` is left unchanged.
    """
    global current_outputs
    try:
        url = f'{API_HOST}/v1/image/generation/outputs'
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            print(f"Error fetching outputs: HTTP {response.status_code}")
            return None
        data = response.json()
        current_outputs = data.get('outputs') or []
        return data
    except Exception as e:
        # Best-effort: the caller shows a generic failure message, so log the
        # cause here instead of swallowing it silently.
        print(f"Error fetching outputs: {e}")
        return None
def _format_output_entry(idx, output):
    """Render one output record as display text and pick its image URL.

    Args:
        idx: 1-based position of the output in the full list.
        output: one output dict from the API response.

    Returns:
        tuple[str, str | None]: the text for this entry (ending with a
        separator line) and the image URL used for the gallery. The URL is
        ``image_url``, else the first item of ``image_urls``, else ``None``.
    """
    separator = "─" * 40 + "\n"

    image_url = output.get('image_url')
    if not image_url:
        fallback_urls = output.get('image_urls')
        if isinstance(fallback_urls, list) and fallback_urls:
            image_url = fallback_urls[0]
    if not image_url:
        return f"{idx}. ⚠️ No image data\n" + separator, None

    # `or` also covers keys that are present but null.
    output_id = str(output.get('id') or 'N/A')
    created_at = output.get('created_at') or 'N/A'
    model_name = output.get('model_name', 'Unknown')
    gallery_status = output.get('gallery_status', 'not_submitted')
    gallery_emoji = {
        'not_submitted': '🔒',
        'submitted': '📤',
        'approved': '✅',
        'rejected': '❌',
    }.get(gallery_status, '❓')
    favorite_emoji = '❤️' if output.get('is_favorited', False) else '🤍'

    # Show a short YYYY-MM-DD date; keep the raw value if it is not ISO 8601.
    if created_at == 'N/A':
        created_date = 'Unknown'
    else:
        try:
            parsed = datetime.fromisoformat(str(created_at).replace('Z', '+00:00'))
            created_date = parsed.strftime('%Y-%m-%d')
        except ValueError:
            created_date = str(created_at)

    lines = [
        f"{idx}. 🔹 **Output {output_id[:8]}...**",
        f"   📅 {created_date}",
        f"   🤖 {model_name}",
        f"   🖼️ {gallery_emoji} {gallery_status}",
        f"   {favorite_emoji} Favorite",
    ]
    generation = output.get('generation')
    if isinstance(generation, dict) and generation:
        prompt = str(generation.get('prompt') or 'No prompt')
        if len(prompt) > 50:
            prompt = prompt[:50] + '...'
        lines.append(f"   📝 {prompt}")
    return "\n".join(lines) + "\n" + separator, image_url


def update_outputs_display():
    """Render the current page of ``current_outputs``.

    ``current_page`` is clamped to the valid range first, in case the
    cached list shrank after the index was set.

    Returns:
        tuple[str, str, str]: ``(display_text, gallery_html, raw_json)``,
        where ``raw_json`` is the whole cached list serialized as JSON.
    """
    global current_outputs, current_page, page_size
    if not current_outputs:
        return ("📭 No outputs found. Generate some images first!",
                "<p>No images found</p>", "[]")

    total = len(current_outputs)
    total_pages = (total + page_size - 1) // page_size  # ceiling division
    current_page = min(max(current_page, 0), total_pages - 1)

    start_idx = current_page * page_size
    end_idx = min(start_idx + page_size, total)

    header = [
        f"📊 Total outputs: {total}",
        f"📄 Page {current_page + 1} of {total_pages}",
        f"🖼️ Showing {start_idx + 1}-{end_idx} of {total}",
        f"⏰ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
    ]
    entries = []
    image_urls = []
    page_outputs = current_outputs[start_idx:end_idx]
    for idx, output in enumerate(page_outputs, start=start_idx + 1):
        text, image_url = _format_output_entry(idx, output)
        entries.append(text)
        if image_url:
            image_urls.append(image_url)
    display_text = "\n".join(header) + "\n" + "".join(entries)

    gallery_html = create_gallery_html(image_urls, "Your Generated Images")
    # Page indicator only. Navigation is driven by the Previous/Next buttons.
    gallery_html += (f'\n<p style="text-align: center; opacity: 0.7;">'
                     f'Page {current_page + 1} of {total_pages}</p>')
    return display_text, gallery_html, json.dumps(current_outputs, indent=2, ensure_ascii=False)
def load_outputs():
    """Reload outputs from the API and render the first page.

    Returns:
        tuple[str, str, str]: the ``update_outputs_display`` triple on
        success, or an error triple when ``fetch_outputs`` returns ``None``.
    """
    global current_page
    current_page = 0
    data = fetch_outputs()
    # fetch_outputs signals failure with None. A successful but empty
    # response (e.g. {}) is not a failure.
    if data is None:
        return "❌ Failed to load outputs", "<p>Failed to load outputs</p>", "[]"
    return update_outputs_display()
def next_page():
    """Advance one page when a later page exists, then re-render.

    The page index is untouched when nothing is cached or the last page is
    already showing.
    """
    global current_page
    if not current_outputs:
        return update_outputs_display()
    page_count = -(-len(current_outputs) // page_size)  # ceiling division
    if current_page + 1 < page_count:
        current_page += 1
    return update_outputs_display()
def prev_page():
    """Step back one page (not past the first page), then re-render."""
    global current_page
    current_page -= 1 if current_page > 0 else 0
    return update_outputs_display()
# ========== CREATE INTERFACE ==========
with gr.Blocks(title="StableCog Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 StableCog Image Generator")
    with gr.Tabs():
        # ========== GENERATE TAB ==========
        with gr.Tab("✨ Generate", id="generate"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### 🎯 Prompt Settings")
                        prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe the image you want to generate...",
                            lines=3,
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt (Optional)",
                            placeholder="What to avoid in the image...",
                            lines=2,
                        )
                        init_image_url = gr.Textbox(
                            label="Init Image URL (Optional)",
                            placeholder="https://example.com/image.jpg",
                            lines=1,
                        )
                        prompt_strength = gr.Slider(
                            label="Prompt Strength (for img2img)",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.8,
                            step=0.1,
                            visible=False,
                        )

                        def toggle_prompt_strength(url):
                            """Show the prompt-strength slider only when an init image URL is set."""
                            return gr.update(visible=bool(url and url.strip()))

                        init_image_url.change(
                            toggle_prompt_strength,
                            inputs=[init_image_url],
                            outputs=[prompt_strength],
                        )

                    with gr.Group():
                        gr.Markdown("### ⚙️ Generation Settings")
                        with gr.Row():
                            model_dropdown = gr.Dropdown(
                                label="Model",
                                choices=[m[0] for m in MODELS],
                                value=DEFAULT_MODEL,
                            )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1024,
                                value=768,
                                step=8,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1024,
                                value=768,
                                step=8,
                            )
                        num_outputs = gr.Slider(
                            label="Number of Images",
                            minimum=1,
                            maximum=4,
                            value=1,
                            step=1,
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=20.0,
                            value=7.0,
                            step=0.5,
                        )
                        inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=10,
                            maximum=50,
                            value=30,
                            step=1,
                        )
                        with gr.Row():
                            scheduler_dropdown = gr.Dropdown(
                                label="Scheduler",
                                choices=[s[0] for s in SCHEDULERS],
                                value=DEFAULT_SCHEDULER,
                            )
                            seed = gr.Textbox(
                                label="Seed (Optional)",
                                placeholder="Leave empty for random",
                                lines=1,
                            )

                with gr.Column(scale=1):
                    generate_btn = gr.Button(
                        "🚀 Generate Image",
                        variant="primary",
                        size="lg",
                    )
                    with gr.Group():
                        gr.Markdown("### 📊 Results")
                        generate_output = gr.Textbox(
                            label="Generation Result",
                            lines=15,
                            interactive=False,
                        )
                    with gr.Group():
                        generate_gallery = gr.HTML(label="Generated Images")
                    with gr.Group():
                        generate_raw = gr.Code(
                            label="Raw API Response",
                            language="json",
                            lines=10,
                            interactive=False,
                        )

            # Input order must match generate_image's positional parameters.
            generate_btn.click(
                generate_image,
                inputs=[
                    prompt, negative_prompt, model_dropdown, width, height,
                    num_outputs, guidance_scale, inference_steps,
                    scheduler_dropdown, seed, init_image_url, prompt_strength,
                ],
                outputs=[generate_output, generate_gallery, generate_raw],
            )

        # ========== OUTPUTS TAB ==========
        with gr.Tab("🖼️ My Outputs", id="outputs"):
            with gr.Row():
                with gr.Column(scale=1):
                    outputs_display = gr.Textbox(label="Output Details", lines=25, interactive=False)
                with gr.Column(scale=2):
                    outputs_gallery = gr.HTML(label="Image Gallery")
            with gr.Row():
                outputs_raw = gr.Code(label="Raw JSON", language="json", lines=10, interactive=False)
            with gr.Row():
                load_outputs_btn = gr.Button("📥 Load My Outputs", variant="primary")
                prev_page_btn = gr.Button("◀ Previous Page")
                next_page_btn = gr.Button("Next Page ▶")

            # Pagination is server-side via these buttons. gr.HTML content is
            # not a reliable place for <script> tags, so no client-side JS is
            # injected here.
            outputs_targets = [outputs_display, outputs_gallery, outputs_raw]
            load_outputs_btn.click(load_outputs, outputs=outputs_targets)
            prev_page_btn.click(prev_page, outputs=outputs_targets)
            next_page_btn.click(next_page, outputs=outputs_targets)

        # ========== MODELS TAB ==========
        with gr.Tab("🤖 Models", id="models"):
            with gr.Row():
                models_display = gr.Textbox(label="Available Models", lines=25, interactive=False)
                models_raw = gr.Code(label="Raw JSON", language="json", lines=25, interactive=False)
            check_models_btn = gr.Button("🔄 Refresh Models", variant="primary")
            check_models_btn.click(get_models, outputs=[models_display, models_raw])

if __name__ == "__main__":
    demo.launch()