import os import json import requests import gradio as gr from datetime import datetime import time # API Configuration API_KEY = os.getenv('StableCogKey') API_HOST = 'https://api.stablecog.com' headers = { 'Authorization': f'Bearer {API_KEY}', 'Content-Type': 'application/json' } # Global state for outputs pagination current_outputs = [] current_page = 0 page_size = 10 # ========== MODEL AND SCHEDULER MANAGEMENT ========== def get_model_list(): """Get list of models for dropdown""" try: url = f'{API_HOST}/v1/image/generation/models' response = requests.get(url, headers=headers, timeout=10) if response.status_code == 200: models = response.json().get('models', []) # Sort: default first, then by name models.sort(key=lambda x: (not x.get('is_default', False), x.get('name', ''))) return [(f"{m['name']}", m['id']) for m in models] except Exception as e: print(f"Error fetching models: {e}") # Fallback to a working model from your successful response return [ ("SDXL 1.0", "0a99668b-45bd-4f7e-aa9c-f9aaa41ef13b"), ("SDXL Lightning", "22b0857d-7edc-4d00-9cd9-45aa509db093") ] def get_scheduler_list(): """Get list of schedulers""" # From your successful response, use the actual scheduler IDs return [ ("Euler", "af2679a4-dbbb-4950-8c06-c3bb15416ef6"), ("Euler A", "6fb13a76-990d-49df-a2ab-7d9d22c33e3d"), ("DDIM", "c5a0bad3-bd9d-4c5c-96e3-9d8e8c0c7a6b"), ("DPMSolver++", "e9c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f"), ("DPM++ 2M Karras", "f1c3a7b3-2b5c-4a5d-8e2d-7b1c9a3d4e5f") ] # Initialize model and scheduler lists MODELS = get_model_list() SCHEDULERS = get_scheduler_list() # Create mapping dictionaries MODEL_MAP = {name: id for name, id in MODELS} SCHEDULER_MAP = {name: id for name, id in SCHEDULERS} # Default values DEFAULT_MODEL = MODELS[0][0] if MODELS else "SDXL 1.0" DEFAULT_SCHEDULER = SCHEDULERS[0][0] if SCHEDULERS else "Euler" # ========== CUSTOM DARK THEME CSS ========== CUSTOM_CSS = """ /* Dark theme background */ .gradio-container { background: linear-gradient(135deg, #0a0a2a 0%, #1a1a3a 100%) !important; min-height: 100vh !important; padding: 20px !important; } /* Main container */ .container { background: rgba(255, 255, 255, 0.05) !important; backdrop-filter: blur(10px) !important; border-radius: 20px !important; padding: 30px !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3) !important; max-width: 1400px !important; margin: 0 auto !important; } /* Tab styling */ .tabs { background: rgba(0, 0, 0, 0.3) !important; border-radius: 10px !important; padding: 5px !important; margin-bottom: 20px !important; } .tab-nav { background: rgba(255, 255, 255, 0.1) !important; border-radius: 8px !important; } .tab-button { background: transparent !important; color: #aaa !important; border: none !important; } .tab-button.selected { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; color: white !important; font-weight: bold !important; } /* Button styling */ button { border-radius: 8px !important; border: none !important; transition: all 0.2s !important; } button.primary { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; color: white !important; font-weight: bold !important; } button.secondary { background: rgba(255, 255, 255, 0.1) !important; color: white !important; border: 1px solid rgba(255, 255, 255, 0.2) !important; } button:hover { transform: translateY(-2px) !important; box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important; } /* Input styling */ input, textarea, .gradio-dropdown { background: rgba(255, 255, 255, 0.05) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; color: white !important; border-radius: 8px !important; } input:focus, textarea:focus, .gradio-dropdown:focus-within { border-color: #667eea !important; box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.2) !important; outline: none !important; } /* Label styling */ label { color: #ddd !important; font-weight: 500 !important; } /* Textbox and code styling */ .gradio-textbox, .gradio-code { background: rgba(0, 0, 0, 0.3) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; color: #eee !important; } /* Slider styling */ .gr-slider > .range { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; } .gr-slider > .range-container { background: rgba(255, 255, 255, 0.1) !important; } /* Group styling */ .gr-group { background: rgba(255, 255, 255, 0.05) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; border-radius: 12px !important; padding: 20px !important; margin-bottom: 20px !important; } /* Markdown styling */ .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 { color: white !important; } .gr-markdown p { color: #bbb !important; } /* Accordion styling */ .gr-accordion { background: rgba(255, 255, 255, 0.05) !important; border: 1px solid rgba(255, 255, 255, 0.1) !important; border-radius: 8px !important; } /* Custom scrollbar */ ::-webkit-scrollbar { width: 8px; height: 8px; } ::-webkit-scrollbar-track { background: rgba(255, 255, 255, 0.05); border-radius: 4px; } ::-webkit-scrollbar-thumb { background: rgba(102, 126, 234, 0.5); border-radius: 4px; } ::-webkit-scrollbar-thumb:hover { background: rgba(102, 126, 234, 0.8); } /* Header styling */ header { background: transparent !important; border-bottom: none !important; } /* Footer hiding */ footer { display: none !important; } """ # ========== MODELS TAB ========== def get_models(): """Fetch and display available models""" try: url = f'{API_HOST}/v1/image/generation/models' response = requests.get(url, headers=headers, timeout=10) if response.status_code == 200: data = response.json() models = data.get('models', []) # Format display display_text = f"š Found {len(models)} models\n" display_text += f"ā° {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" for i, model in enumerate(models, 1): name = model.get('name', 'Unknown') model_type = model.get('type', 'unknown') description = model.get('description', 'No description') display_text += f"{i}. š¹ **{name}**\n" display_text += f" š {description}\n" display_text += f" š·ļø Type: {model_type}\n" display_text += f" š Public: {'ā ' if model.get('is_public') else 'ā'}\n" display_text += f" ā Default: {'ā ' if model.get('is_default') else 'ā'}\n" display_text += f" š„ Community: {'ā ' if model.get('is_community') else 'ā'}\n" display_text += "ā" * 40 + "\n" return display_text, str(data) else: return f"ā Error {response.status_code}", f"Error: {response.text}" except Exception as e: return f"ā Error: {str(e)}", "No data" # ========== GENERATE TAB ========== def generate_image(prompt, negative_prompt, model_name, width, height, num_outputs, guidance_scale, inference_steps, scheduler_name, seed, init_image_url, prompt_strength): """Generate images using StableCog API""" try: url = f'{API_HOST}/v1/image/generation/create' # Get actual IDs from maps model_id = MODEL_MAP.get(model_name, MODELS[0][1] if MODELS else "0a99668b-45bd-4f7e-aa9c-f9aaa41ef13b") scheduler_id = SCHEDULER_MAP.get(scheduler_name, "af2679a4-dbbb-4950-8c06-c3bb15416ef6") # Prepare request data data = { "prompt": prompt, "model_id": model_id, "width": int(width), "height": int(height), "num_outputs": int(num_outputs), "guidance_scale": float(guidance_scale), "inference_steps": int(inference_steps), "scheduler_id": scheduler_id, } # Add optional fields if provided if negative_prompt and negative_prompt.strip(): data["negative_prompt"] = negative_prompt.strip() if seed and seed.strip(): try: data["seed"] = int(seed.strip()) except: # Generate random seed if invalid import random data["seed"] = random.randint(1, 1000000000) if init_image_url and init_image_url.strip(): data["init_image_url"] = init_image_url.strip() if prompt_strength is not None: data["prompt_strength"] = float(prompt_strength) # Show loading state start_time = time.time() # Make API request response = requests.post( url, json=data, headers=headers, timeout=60 # Longer timeout for generation ) generation_time = time.time() - start_time if response.status_code == 200: result = response.json() outputs = result.get('outputs', []) remaining_credits = result.get('remaining_credits', 0) settings = result.get('settings', {}) # Format response display_text = f"ā Generation successful!\n" display_text += f"šŖ Remaining credits: {remaining_credits}\n" display_text += f"š¼ļø Generated {len(outputs)} image(s)\n" display_text += f"ā±ļø Generation time: {generation_time:.1f}s\n" display_text += f"ā° {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" # Get image URLs for gallery image_urls = [] for i, output in enumerate(outputs, 1): output_id = output.get('id', 'N/A') image_url = output.get('url', '') display_text += f"{i}. š¹ Output ID: {output_id}\n" display_text += f" š· URL: {image_url}\n" if image_url: image_urls.append(image_url) # Create gallery HTML gallery_html = "" if image_urls: gallery_html = create_gallery_html(image_urls, "šØ Generated Images") # Show settings used display_text += "\nāļø Settings used:\n" for key, value in settings.items(): display_text += f" {key}: {value}\n" return display_text, gallery_html, str(result) else: error_msg = f"ā Generation failed: {response.status_code}" try: error_detail = response.json() error_msg += f"\nDetails: {error_detail.get('error', str(error_detail))}" except: error_msg += f"\nResponse: {response.text[:200]}..." return error_msg, "", error_msg except Exception as e: error_msg = f"ā Error: {str(e)}" return error_msg, "", error_msg # ========== OUTPUTS TAB ========== def create_gallery_html(image_urls, title="Gallery"): """Create HTML gallery with lightbox""" if not image_urls: return """
Create, browse, and manage your AI-generated images
""") with gr.Tabs(): # ========== GENERATE TAB ========== with gr.Tab("⨠Generate", id="generate"): with gr.Row(): with gr.Column(scale=1): with gr.Group(): gr.Markdown("### šÆ Prompt Settings") prompt = gr.Textbox( label="Prompt", placeholder="A beautiful sunset over mountains, digital art...", lines=3, value="A majestic dragon flying over a fantasy castle, digital art, epic lighting" ) negative_prompt = gr.Textbox( label="Negative Prompt (Optional)", placeholder="blurry, low quality, distorted...", lines=2, value="blurry, distorted, low quality, ugly" ) init_image_url = gr.Textbox( label="Init Image URL (Optional - for img2img)", placeholder="https://example.com/image.jpg", lines=1 ) prompt_strength = gr.Slider( label="Prompt Strength (for img2img)", minimum=0.0, maximum=1.0, value=0.8, step=0.1, visible=False ) # Show prompt strength only when init image is provided init_image_url.change( lambda x: gr.update(visible=bool(x and x.strip())), inputs=[init_image_url], outputs=[prompt_strength] ) with gr.Group(): gr.Markdown("### āļø Generation Settings") with gr.Row(): model_dropdown = gr.Dropdown( label="Model", choices=[m[0] for m in MODELS], value=DEFAULT_MODEL, info="Select which AI model to use" ) with gr.Row(): width = gr.Slider( label="Width", minimum=256, maximum=1024, value=768, step=8, info="Image width (pixels)" ) height = gr.Slider( label="Height", minimum=256, maximum=1024, value=768, step=8, info="Image height (pixels)" ) num_outputs = gr.Slider( label="Number of Images", minimum=1, maximum=4, value=1, step=1, info="How many images to generate (uses more credits)" ) guidance_scale = gr.Slider( label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.5, info="How closely to follow the prompt" ) inference_steps = gr.Slider( label="Inference Steps", minimum=10, maximum=50, value=30, step=1, info="More steps = more detail but slower" ) with gr.Row(): scheduler_dropdown = gr.Dropdown( label="Scheduler", choices=[s[0] for s in SCHEDULERS], value=DEFAULT_SCHEDULER, info="Diffusion sampling method" ) seed = gr.Textbox( label="Seed (Optional)", placeholder="Leave empty for random", lines=1, info="Same seed + same settings = same image" ) with gr.Column(scale=1): generate_btn = gr.Button( "š Generate Image", variant="primary", size="lg", scale=1 ) with gr.Group(): gr.Markdown("### š Generation Results") generate_output = gr.Textbox( label="Status & Details", lines=15, interactive=False, value="Ready to generate! Enter a prompt and click the button above." ) with gr.Group(): gr.Markdown("### š¼ļø Generated Images") generate_gallery = gr.HTML( label="", value="""