StablecogAPI / app.py
MySafeCode's picture
Update app.py
f3670b8 verified
import os
import json
import requests
import gradio as gr
from datetime import datetime
# API Configuration
# The key is read from the `StableCogKey` environment variable. When it is
# unset, os.environ.get returns None and the header reads "Bearer None".
API_KEY = os.environ.get('StableCogKey')
API_HOST = 'https://api.stablecog.com'
# Passed as `headers=` on every API request made in this module.
headers = {
    'Authorization': 'Bearer {}'.format(API_KEY),
    'Content-Type': 'application/json',
}
# Global state for outputs pagination:
#   current_outputs - outputs from the most recent fetch
#   current_page    - zero-based index of the page being shown
#   page_size       - outputs rendered per page
current_outputs, current_page, page_size = [], 0, 10
# ========== MODEL AND SCHEDULER MANAGEMENT ==========
def get_model_list():
    """Return ``(display_name, model_id)`` pairs for the model dropdown.

    Fetches ``GET {API_HOST}/v1/image/generation/models`` and orders the
    models so entries flagged ``is_default`` come first, then by name.
    Entries missing a name or an id are skipped instead of aborting the
    whole list.

    The result is never empty: a single hard-coded SDXL 1.0 choice is
    returned when any of these happen:
    - the request fails,
    - the API answers with a non-200 status,
    - the response contains no usable models.
    """
    fallback = [("SDXL 1.0", "22b0857d-7edc-4d00-9cd9-45aa509db093")]
    try:
        url = f'{API_HOST}/v1/image/generation/models'
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code == 200:
            models = response.json().get('models') or []
            # Both fields are required: the name is displayed, the id is sent to the API.
            usable = [m for m in models if m.get('name') and m.get('id')]
            # Sort: default first, then by name.
            usable.sort(key=lambda m: (not m.get('is_default', False), str(m['name'])))
            if usable:
                return [(str(m['name']), m['id']) for m in usable]
            print("Error fetching models: response contained no usable models")
        else:
            print(f"Error fetching models: HTTP {response.status_code}")
    except Exception as e:
        print(f"Error fetching models: {e}")
    # Fallback to SDXL
    return fallback
def get_scheduler_list():
    """Return the hard-coded ``(display_name, scheduler_id)`` pairs.

    The list is not fetched from the API. A fresh list is built on each call.

    NOTE(review): every ID from "DPMSolver++" onward shares the suffix
    ``-2b5c-4a5d-8e2d-7b1c9a3d4e5f`` and they look like placeholders.
    Confirm them against the StableCog API before relying on them.
    """
    display_names = (
        "Euler",
        "Euler A",
        "DDIM",
        "DPMSolver++",
        "DPM++ 2M Karras",
        "DPM++ SDE",
        "DPM++ SDE Karras",
        "Heun",
        "LMS",
        "LMS Karras",
    )
    scheduler_ids = (
        "b7224e56-1440-43b9-ac86-66d66f9e8c91",
        "6fb13a76-990d-49df-a2ab-7d9d22c33e3d",
        "c5a0bad3-bd9d-4c5c-96e3-9d8e8c0c7a6b",
        "e9c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "f1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "a1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "b1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "c1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "d1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
        "e1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f",
    )
    return list(zip(display_names, scheduler_ids))
# Initialize model and scheduler lists once, at import time.
# get_model_list() issues an HTTP request, so importing this module contacts the API.
MODELS = get_model_list()
SCHEDULERS = get_scheduler_list()
# Display-name -> API id lookups used when building a generation request.
MODEL_MAP = dict(MODELS)
SCHEDULER_MAP = dict(SCHEDULERS)
# Pre-selected dropdown values: the first entry of each list, or a hard-coded
# name when a list is empty.
if MODELS:
    DEFAULT_MODEL = MODELS[0][0]
else:
    DEFAULT_MODEL = "SDXL 1.0"
if SCHEDULERS:
    DEFAULT_SCHEDULER = SCHEDULERS[0][0]
else:
    DEFAULT_SCHEDULER = "Euler"
# ========== MODELS TAB ==========
def get_models():
    """Fetch the model catalogue and format it for the Models tab.

    Returns:
        A ``(display_text, raw_json)`` tuple.
        - ``display_text`` lists each model's name, description, type and
          public/default/community flags.
        - ``raw_json`` is the response body pretty-printed as real JSON,
          which is what the ``gr.Code(language="json")`` panel expects.
        On a non-200 status or any exception, both elements describe the
        error instead.
    """
    try:
        url = f'{API_HOST}/v1/image/generation/models'
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            return f"❌ Error {response.status_code}", f"Error: {response.text}"

        data = response.json()
        models = data.get('models') or []

        lines = [
            f"📊 Found {len(models)} models",
            f"⏰ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "",
        ]
        for i, model in enumerate(models, 1):
            name = model.get('name') or 'Unknown'
            model_type = model.get('type') or 'unknown'
            description = model.get('description') or 'No description'
            lines += [
                f"{i}. 🔹 **{name}**",
                f" 📝 {description}",
                f" 🏷️ Type: {model_type}",
                f" 🌍 Public: {'✅' if model.get('is_public') else '❌'}",
                f" ⭐ Default: {'✅' if model.get('is_default') else '❌'}",
                f" 👥 Community: {'✅' if model.get('is_community') else '❌'}",
                "─" * 40,
            ]
        # The trailing "\n" keeps every line newline-terminated.
        display_text = "\n".join(lines) + "\n"
        return display_text, json.dumps(data, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"❌ Error: {str(e)}", "No data"
# ========== GENERATE TAB ==========
def generate_image(prompt, negative_prompt, model_name, width, height,
                   num_outputs, guidance_scale, inference_steps,
                   scheduler_name, seed, init_image_url, prompt_strength):
    """Submit a generation request to the StableCog API and format the result.

    Args:
        prompt: Text prompt. A blank prompt is rejected without calling the API.
        negative_prompt: Optional. Sent (stripped) only when non-blank.
        model_name: Display name looked up in ``MODEL_MAP``. Unknown names
            fall back to the first ``MODELS`` entry, or the SDXL id.
        width, height, num_outputs, inference_steps: Coerced to ``int``.
        guidance_scale: Coerced to ``float``.
        scheduler_name: Display name looked up in ``SCHEDULER_MAP``. Unknown
            names fall back to the "Euler" id.
        seed: Optional text. Sent only when it parses as an integer.
        init_image_url: Optional source image URL for img2img.
        prompt_strength: Sent only together with ``init_image_url``.

    Returns:
        ``(summary_text, gallery_html, raw_json)``. On failure the first and
        third elements carry the error message and ``gallery_html`` is "".
    """
    if prompt is None or not str(prompt).strip():
        # Reject locally instead of sending an empty prompt to the API.
        error_msg = "❌ Error: Prompt must not be empty."
        return error_msg, "", error_msg

    try:
        url = f'{API_HOST}/v1/image/generation/create'
        # Get actual IDs from maps
        model_id = MODEL_MAP.get(model_name, MODELS[0][1] if MODELS else "22b0857d-7edc-4d00-9cd9-45aa509db093")
        scheduler_id = SCHEDULER_MAP.get(scheduler_name, "b7224e56-1440-43b9-ac86-66d66f9e8c91")

        data = {
            "prompt": prompt,
            "model_id": model_id,
            "width": int(width),
            "height": int(height),
            "num_outputs": int(num_outputs),
            "guidance_scale": float(guidance_scale),
            "inference_steps": int(inference_steps),
            "scheduler_id": scheduler_id,
        }
        if negative_prompt and negative_prompt.strip():
            data["negative_prompt"] = negative_prompt.strip()
        seed_text = "" if seed is None else str(seed).strip()
        if seed_text:
            try:
                data["seed"] = int(seed_text)
            except ValueError:
                # A non-numeric seed is ignored (random seed) rather than failing the request.
                print(f"Ignoring non-integer seed: {seed_text!r}")
        if init_image_url and init_image_url.strip():
            data["init_image_url"] = init_image_url.strip()
            # prompt_strength only applies to img2img. The UI hides the slider
            # without an init image, but its value is still passed in here.
            if prompt_strength is not None:
                data["prompt_strength"] = float(prompt_strength)

        # Debug: Print the data being sent
        print(f"Sending data: {json.dumps(data, indent=2)}")
        response = requests.post(
            url,
            json=data,
            headers=headers,
            timeout=30  # Longer timeout for generation
        )
        print(f"Response status: {response.status_code}")
        print(f"Response text: {response.text[:200]}...")

        if response.status_code != 200:
            error_msg = f"❌ Generation failed: {response.status_code}"
            try:
                error_detail = response.json()
                error_msg += f"\nDetails: {json.dumps(error_detail, indent=2)}"
            except ValueError:  # body is not JSON
                error_msg += f"\nResponse: {response.text}"
            return error_msg, "", error_msg

        result = response.json()
        outputs = result.get('outputs') or []
        remaining_credits = result.get('remaining_credits', 0)
        settings = result.get('settings') or {}

        lines = [
            "✅ Generation successful!",
            f"🪙 Remaining credits: {remaining_credits}",
            f"🖼️ Generated {len(outputs)} image(s)",
            f"⏰ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "",
        ]
        image_urls = []
        for i, output in enumerate(outputs, 1):
            output_id = output.get('id', 'N/A')
            image_url = str(output.get('url') or '')
            # Truncate long URLs for display only; the gallery uses the full URL.
            shown_url = image_url if len(image_url) <= 60 else image_url[:60] + "..."
            lines.append(f"{i}. Output ID: {output_id}")
            lines.append(f" 📷 URL: {shown_url}")
            if image_url:
                image_urls.append(image_url)

        gallery_html = create_gallery_html(image_urls, "Generated Images") if image_urls else ""

        # Show settings used
        lines.append("")
        lines.append("⚙️ Settings used:")
        for key, value in settings.items():
            lines.append(f" {key}: {value}")

        display_text = "\n".join(lines) + "\n"
        return display_text, gallery_html, json.dumps(result, indent=2, ensure_ascii=False)
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, "", error_msg
# ========== OUTPUTS TAB ==========
def create_gallery_html(image_urls, title="Gallery"):
    """Build an HTML image grid with a click-to-enlarge lightbox.

    Scripts inserted via ``innerHTML`` are not executed by browsers.
    NOTE(review): ``gr.HTML`` is assumed to insert markup that way; confirm
    for the Gradio version in use.

    Because of that, all interactivity lives in inline event attributes.
    These locate the lightbox relative to this gallery's own wrapper
    element, so there are no global functions and no document-wide ids.
    Several galleries on one page (e.g. in different tabs) therefore do not
    interfere. Escape closes the lightbox because it is focused when opened.

    Args:
        image_urls: Iterable of image URLs. Each one becomes a card.
        title: Heading shown above the grid.

    Returns:
        The HTML string. ``title`` and every URL are HTML-escaped, so values
        containing quotes or ``<`` cannot break out of their element or
        attribute.
    """
    from html import escape  # local import keeps the file's import block unchanged

    # Plain (non-f) string, so CSS braces need no doubling. The .pagination /
    # .page-* rules style pagination markup that callers may append after
    # this gallery.
    style = """
<style>
.image-gallery {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
    gap: 15px;
    margin-bottom: 20px;
}
.image-card {
    border-radius: 10px;
    overflow: hidden;
    background: rgba(255,255,255,0.1);
    padding: 10px;
    cursor: pointer;
    transition: transform 0.2s;
}
.image-card:hover {
    transform: scale(1.02);
    background: rgba(255,255,255,0.15);
}
.image-card img {
    width: 100%;
    height: 200px;
    object-fit: cover;
    border-radius: 8px;
}
.image-meta {
    margin-top: 8px;
    font-size: 12px;
    line-height: 1.4;
}
.lightbox {
    display: none;
    position: fixed;
    z-index: 9999;
    left: 0;
    top: 0;
    width: 100%;
    height: 100%;
    background: rgba(0,0,0,0.9);
    justify-content: center;
    align-items: center;
}
.lightbox:focus {
    outline: none;
}
.lightbox img {
    max-width: 90%;
    max-height: 90%;
    border-radius: 10px;
    box-shadow: 0 20px 60px rgba(0,0,0,0.5);
}
.lightbox.active {
    display: flex;
}
.close-btn {
    position: absolute;
    top: 20px;
    right: 30px;
    color: white;
    font-size: 40px;
    font-weight: bold;
    cursor: pointer;
    z-index: 10000;
}
.pagination {
    display: flex;
    justify-content: center;
    gap: 10px;
    margin-top: 20px;
}
.page-btn {
    padding: 8px 16px;
    background: rgba(255,255,255,0.1);
    border: 1px solid rgba(255,255,255,0.2);
    border-radius: 8px;
    color: white;
    cursor: pointer;
    transition: background 0.2s;
}
.page-btn:hover {
    background: rgba(255,255,255,0.2);
}
.page-btn.disabled {
    opacity: 0.5;
    cursor: not-allowed;
}
.page-info {
    display: flex;
    align-items: center;
    padding: 8px 16px;
    color: white;
}
</style>
"""

    # Clicking the backdrop or the close button closes the lightbox: the close
    # button's click bubbles up to the backdrop handler. Clicks on the
    # enlarged image are stopped so they do not close it.
    lightbox = """
<div class="lightbox" tabindex="-1"
     onclick="this.classList.remove('active')"
     onkeydown="if (event.key === 'Escape') this.classList.remove('active')">
  <span class="close-btn" title="Close">&times;</span>
  <img alt="Enlarged image" onclick="event.stopPropagation()">
</div>
"""

    # Copy the clicked card's (already resolved) src into this gallery's
    # lightbox. The URL never has to be embedded in a JS string literal.
    open_lightbox_js = (
        "var box = this.closest('.sc-gallery').querySelector('.lightbox');"
        "box.querySelector('img').src = this.querySelector('img').src;"
        "box.classList.add('active');"
        "box.focus();"
    )

    cards = []
    for number, image_url in enumerate(image_urls, start=1):
        safe_url = escape(str(image_url), quote=True)
        cards.append(
            f'<div class="image-card" onclick="{open_lightbox_js}">'
            f'<img src="{safe_url}" alt="Image {number}">'
            f'<div class="image-meta"><div>Image #{number}</div></div>'
            '</div>'
        )
    cards_html = "\n".join(cards)

    return (
        f'<div class="sc-gallery">{style}{lightbox}'
        f'<h3>{escape(str(title))}</h3>'
        f'<div class="image-gallery">{cards_html}</div>'
        '</div>'
    )
def fetch_outputs():
    """Fetch the user's generation outputs and cache them for pagination.

    On success, the ``outputs`` list from the response is stored in the
    module-level ``current_outputs`` and the parsed response body is
    returned.

    On a non-200 status or any request/parse error:
    - the cache is cleared, so stale results from an earlier load are not
      paginated;
    - the error is printed instead of being silently swallowed;
    - ``None`` is returned.
    """
    global current_outputs
    try:
        url = f'{API_HOST}/v1/image/generation/outputs'
        response = requests.get(url, headers=headers, timeout=10)
        if response.status_code != 200:
            print(f"Error fetching outputs: HTTP {response.status_code}")
            current_outputs = []
            return None
        data = response.json()
        current_outputs = data.get('outputs') or []
        return data
    except Exception as e:
        print(f"Error fetching outputs: {e}")
        current_outputs = []
        return None
def _format_output_date(created_at):
    """Return an ISO-8601 ``created_at`` value as ``YYYY-MM-DD``.

    Returns ``'Unknown'`` when the value is missing or empty, and the raw
    value as text when it does not parse.
    """
    if not created_at:
        return 'Unknown'
    try:
        # fromisoformat only accepts a trailing 'Z' from Python 3.11 on.
        dt = datetime.fromisoformat(str(created_at).replace('Z', '+00:00'))
        return dt.strftime('%Y-%m-%d')
    except ValueError:
        return str(created_at)


def update_outputs_display():
    """Render the current page of cached outputs for the "My Outputs" tab.

    Reads the module-level pagination state (``current_outputs``,
    ``current_page``, ``page_size``). ``current_page`` is first clamped into
    ``[0, total_pages - 1]``, so a stale index can never produce an empty
    page or an inverted "Showing" range.

    Returns:
        ``(details_text, gallery_html, raw_json)``:
        - a per-output text summary for the page;
        - the page's image gallery followed by a page indicator;
        - every cached output serialized as JSON.
    """
    global current_page

    if not current_outputs:
        return (
            "📭 No outputs found. Generate some images first!",
            "<div style='text-align: center; padding: 40px; color: #888;'>No images found</div>",
            "[]",
        )

    total = len(current_outputs)
    total_pages = (total + page_size - 1) // page_size  # Ceiling division
    current_page = min(max(current_page, 0), total_pages - 1)

    start_idx = current_page * page_size
    end_idx = min(start_idx + page_size, total)
    page_outputs = current_outputs[start_idx:end_idx]

    gallery_emojis = {
        'not_submitted': '🔒',
        'submitted': '📤',
        'approved': '✅',
        'rejected': '❌',
    }

    lines = [
        f"📊 Total outputs: {total}",
        f"📄 Page {current_page + 1} of {total_pages}",
        f"🖼️ Showing {start_idx + 1}-{end_idx} of {total}",
        f"⏰ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
    ]
    image_urls = []
    for idx, output in enumerate(page_outputs, start=start_idx + 1):
        image_url = output.get('image_url')
        if not image_url:
            # Fall back to the first entry of an `image_urls` list when present.
            url_list = output.get('image_urls')
            if isinstance(url_list, list) and url_list:
                image_url = url_list[0]
        if not image_url:
            lines.append(f"{idx}. ⚠️ No image data")
            lines.append("─" * 40)
            continue
        image_urls.append(image_url)

        raw_id = output.get('id')
        # str() so slicing works even if the API returns a numeric id.
        output_id = 'N/A' if raw_id is None else str(raw_id)
        gallery_status = output.get('gallery_status', 'not_submitted')
        gallery_emoji = gallery_emojis.get(gallery_status, '❓')
        favorite_emoji = '❤️' if output.get('is_favorited', False) else '🤍'

        lines += [
            f"{idx}. 🔹 **Output {output_id[:8]}...**",
            f" 🕒 {_format_output_date(output.get('created_at'))}",
            f" 🤖 {output.get('model_name', 'Unknown')}",
            f" 🖼️ {gallery_emoji} {gallery_status}",
            f" {favorite_emoji} Favorite",
        ]
        # Show generation details if available
        generation = output.get('generation')
        if isinstance(generation, dict) and generation:
            prompt = str(generation.get('prompt') or 'No prompt')
            if len(prompt) > 50:
                prompt = prompt[:50] + '...'
            lines.append(f" 📝 {prompt}")
        lines.append("─" * 40)

    gallery_html = create_gallery_html(image_urls, "Your Generated Images")
    # Only a page indicator is rendered here. In-HTML Previous/Next buttons
    # would have to dispatch an event that nothing in this file listens for
    # (and a <script> inside gr.HTML is presumably not executed). Paging is
    # driven by the Gradio buttons wired to prev_page / next_page.
    gallery_html += (
        '<div class="pagination">'
        f'<div class="page-info">Page {current_page + 1} of {total_pages}</div>'
        '</div>'
    )

    return (
        "\n".join(lines) + "\n",
        gallery_html,
        json.dumps(current_outputs, indent=2, ensure_ascii=False),
    )
def load_outputs():
    """Reload outputs from the API and show the first page.

    Only a failed fetch (``fetch_outputs`` returning ``None``) is reported as
    an error. A successful response that happens to be empty is handed to
    ``update_outputs_display``, which shows its "No outputs found" state.

    Returns:
        ``(details_text, gallery_html, raw_json)`` for the outputs tab.
    """
    global current_page
    current_page = 0
    if fetch_outputs() is None:
        return (
            "❌ Failed to load outputs",
            "<div style='color: red; padding: 20px;'>Failed to load outputs</div>",
            "[]",
        )
    return update_outputs_display()
def next_page():
    """Advance the outputs view by one page and re-render it.

    Does nothing to the page index when already on the last page or when
    there are no cached outputs.
    """
    global current_page
    if current_outputs:
        full_pages, leftover = divmod(len(current_outputs), page_size)
        page_count = full_pages + (1 if leftover else 0)
        if current_page + 1 < page_count:
            current_page += 1
    return update_outputs_display()
def prev_page():
    """Step the outputs view back by one page and re-render it.

    Does nothing to the page index when already on the first page.
    """
    global current_page
    if current_page <= 0:
        return update_outputs_display()
    current_page -= 1
    return update_outputs_display()
# ========== CREATE INTERFACE ==========
# Build the three-tab Gradio UI (Generate / My Outputs / Models). Event
# handlers are the module-level functions defined above.
# NOTE(review): this copy of the file had lost its indentation, so the
# nesting below is reconstructed. In particular, confirm whether
# num_outputs / guidance_scale / inference_steps were meant to share the
# width/height Row.
with gr.Blocks(title="StableCog Dashboard", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 StableCog Image Generator")
    with gr.Tabs():
        # ========== GENERATE TAB ==========
        # Left column: prompt and generation settings. Right column: button and results.
        with gr.Tab("✨ Generate", id="generate"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### 🎯 Prompt Settings")
                        prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe the image you want to generate...",
                            lines=3
                        )
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt (Optional)",
                            placeholder="What to avoid in the image...",
                            lines=2
                        )
                        init_image_url = gr.Textbox(
                            label="Init Image URL (Optional)",
                            placeholder="https://example.com/image.jpg",
                            lines=1
                        )
                        # Starts hidden; made visible by toggle_prompt_strength below.
                        # NOTE(review): a hidden component presumably still supplies
                        # its value (0.8 by default) as an input to generate_image.
                        # Confirm against the Gradio version in use.
                        prompt_strength = gr.Slider(
                            label="Prompt Strength (for img2img)",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.8,
                            step=0.1,
                            visible=False
                        )
                        # Show prompt strength only when init image is provided
                        def toggle_prompt_strength(url):
                            """Return a visibility update: show the slider iff ``url`` is non-blank."""
                            return gr.update(visible=bool(url and url.strip()))
                        init_image_url.change(
                            toggle_prompt_strength,
                            inputs=[init_image_url],
                            outputs=[prompt_strength]
                        )
                    with gr.Group():
                        gr.Markdown("### ⚙️ Generation Settings")
                        with gr.Row():
                            # Choices are model display names; generate_image maps them to ids via MODEL_MAP.
                            model_dropdown = gr.Dropdown(
                                label="Model",
                                choices=[m[0] for m in MODELS],
                                value=DEFAULT_MODEL
                            )
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1024,
                                value=768,
                                step=8
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1024,
                                value=768,
                                step=8
                            )
                        num_outputs = gr.Slider(
                            label="Number of Images",
                            minimum=1,
                            maximum=4,
                            value=1,
                            step=1
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance Scale",
                            minimum=1.0,
                            maximum=20.0,
                            value=7.0,
                            step=0.5
                        )
                        inference_steps = gr.Slider(
                            label="Inference Steps",
                            minimum=10,
                            maximum=50,
                            value=30,
                            step=1
                        )
                        with gr.Row():
                            # Choices are scheduler display names; generate_image maps them via SCHEDULER_MAP.
                            scheduler_dropdown = gr.Dropdown(
                                label="Scheduler",
                                choices=[s[0] for s in SCHEDULERS],
                                value=DEFAULT_SCHEDULER
                            )
                            # Free text; generate_image only sends it when it parses as an int.
                            seed = gr.Textbox(
                                label="Seed (Optional)",
                                placeholder="Leave empty for random",
                                lines=1
                            )
                with gr.Column(scale=1):
                    generate_btn = gr.Button(
                        "🚀 Generate Image",
                        variant="primary",
                        size="lg"
                    )
                    with gr.Group():
                        gr.Markdown("### 📊 Results")
                        generate_output = gr.Textbox(
                            label="Generation Result",
                            lines=15,
                            interactive=False
                        )
                    with gr.Group():
                        generate_gallery = gr.HTML(label="Generated Images")
                    with gr.Group():
                        generate_raw = gr.Code(
                            label="Raw API Response",
                            language="json",
                            lines=10,
                            interactive=False
                        )
            # Connect generate button.
            # The inputs order must match generate_image's positional parameters;
            # the outputs order matches its 3-tuple return value.
            generate_btn.click(
                generate_image,
                inputs=[
                    prompt, negative_prompt, model_dropdown, width, height,
                    num_outputs, guidance_scale, inference_steps,
                    scheduler_dropdown, seed, init_image_url, prompt_strength
                ],
                outputs=[generate_output, generate_gallery, generate_raw]
            )
        # ========== OUTPUTS TAB ==========
        with gr.Tab("🖼️ My Outputs", id="outputs"):
            with gr.Row():
                with gr.Column(scale=1):
                    outputs_display = gr.Textbox(label="Output Details", lines=25)
                with gr.Column(scale=2):
                    outputs_gallery = gr.HTML(label="Image Gallery")
            with gr.Row():
                outputs_raw = gr.Code(label="Raw JSON", language="json", lines=10)
            with gr.Row():
                load_outputs_btn = gr.Button("🔄 Load My Outputs", variant="primary")
                prev_page_btn = gr.Button("◀ Previous Page")
                next_page_btn = gr.Button("Next Page ▶")
            # Add pagination JavaScript
            # NOTE(review): nothing in this file listens for the 'gradio_pagination'
            # event this script dispatches. Browsers also do not execute <script>
            # elements inserted via innerHTML, so this is likely dead code.
            # Confirm before relying on it.
            js = """
<script>
window.paginationButtonClick = function(page) {
const event = new CustomEvent('gradio_pagination', { detail: { page: page } });
document.dispatchEvent(event);
}
</script>
"""
            gr.HTML(js)
            # Connect buttons. Only "Load" calls the API (via fetch_outputs);
            # Previous/Next re-render pages from the cached outputs list.
            load_outputs_btn.click(
                load_outputs,
                outputs=[outputs_display, outputs_gallery, outputs_raw]
            )
            prev_page_btn.click(
                prev_page,
                outputs=[outputs_display, outputs_gallery, outputs_raw]
            )
            next_page_btn.click(
                next_page,
                outputs=[outputs_display, outputs_gallery, outputs_raw]
            )
        # ========== MODELS TAB ==========
        # The model listing is fetched only when the button is clicked.
        with gr.Tab("🤖 Models", id="models"):
            with gr.Row():
                models_display = gr.Textbox(label="Available Models", lines=25)
                models_raw = gr.Code(label="Raw JSON", language="json", lines=25)
            check_models_btn = gr.Button("🔄 Refresh Models", variant="primary")
            check_models_btn.click(get_models, outputs=[models_display, models_raw])

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()