kai-api-gateway / provider_state.py
KiWA001's picture
feat: Implement OpenCode Terminal Portal
2a87f4c
"""
Provider State Manager
----------------------
Manages enabled/disabled state of providers with Supabase persistence.
Uses kaiapi_ prefixed table names for multi-project organization.
"""
import logging
from typing import Dict, Optional
from db import get_supabase
from config import PROVIDERS
logger = logging.getLogger("kai_api.provider_state")
# Table name with kaiapi_ prefix
TABLE_NAME = "kaiapi_provider_states"
class ProviderStateManager:
"""Manages provider enable/disable state with Supabase persistence."""
def __init__(self):
self._providers: Dict[str, dict] = {}
self._initialized = False
async def initialize(self):
"""Load provider states from Supabase or use defaults."""
if self._initialized:
return
supabase = get_supabase()
if supabase:
try:
# Try to load from Supabase (using kaiapi_ prefixed table)
res = supabase.table(TABLE_NAME).select("*").execute()
if res.data:
# Load existing states
existing_ids = set()
for row in res.data:
provider_id = row.get("provider_id")
if provider_id:
self._providers[provider_id] = {
"enabled": row.get("enabled", False),
"name": row.get("name", provider_id),
"type": row.get("type", "api"),
}
existing_ids.add(provider_id)
# Add any new providers from PROVIDERS that aren't in DB
for provider_id, config in PROVIDERS.items():
if provider_id not in existing_ids:
try:
supabase.table(TABLE_NAME).insert({
"provider_id": provider_id,
"enabled": config["enabled"],
"name": config["name"],
"type": config["type"]
}).execute()
self._providers[provider_id] = config.copy()
logger.info(f"✅ Added new provider: {provider_id}")
except Exception as e:
logger.warning(f"Could not add provider {provider_id}: {e}")
logger.info(f"✅ Loaded {len(self._providers)} provider states from Supabase")
else:
# Initialize with defaults
await self._initialize_defaults(supabase)
except Exception as e:
logger.error(f"❌ Failed to load provider states: {e}")
# Fall back to defaults
self._providers = PROVIDERS.copy()
else:
# No Supabase, use defaults
self._providers = PROVIDERS.copy()
self._initialized = True
async def _initialize_defaults(self, supabase):
"""Initialize provider states with defaults in Supabase."""
for provider_id, config in PROVIDERS.items():
self._providers[provider_id] = config.copy()
try:
supabase.table(TABLE_NAME).insert({
"provider_id": provider_id,
"enabled": config["enabled"],
"name": config["name"],
"type": config["type"]
}).execute()
except Exception as e:
logger.warning(f"Could not insert provider {provider_id}: {e}")
logger.info(f"✅ Initialized {len(self._providers)} default provider states")
def get_provider_state(self, provider_id: str) -> Optional[dict]:
"""Get state for a specific provider."""
return self._providers.get(provider_id)
def is_enabled(self, provider_id: str) -> bool:
"""Check if a provider is enabled."""
provider = self._providers.get(provider_id)
return provider.get("enabled", False) if provider else False
def get_all_providers(self) -> Dict[str, dict]:
"""Get all provider states."""
return self._providers.copy()
def get_enabled_providers(self) -> Dict[str, dict]:
"""Get only enabled providers."""
return {
k: v for k, v in self._providers.items()
if v.get("enabled", False)
}
async def set_provider_state(self, provider_id: str, enabled: bool) -> bool:
"""Enable or disable a provider and persist to Supabase."""
if provider_id not in self._providers:
logger.error(f"Unknown provider: {provider_id}")
return False
# Update local state
self._providers[provider_id]["enabled"] = enabled
# Persist to Supabase
supabase = get_supabase()
if supabase:
try:
# Check if row exists
res = supabase.table(TABLE_NAME).select("id").eq("provider_id", provider_id).execute()
if res.data:
# Update existing
supabase.table(TABLE_NAME).update({
"enabled": enabled
}).eq("provider_id", provider_id).execute()
else:
# Insert new
supabase.table(TABLE_NAME).insert({
"provider_id": provider_id,
"enabled": enabled,
"name": self._providers[provider_id]["name"],
"type": self._providers[provider_id]["type"]
}).execute()
logger.info(f"✅ Provider '{provider_id}' {'enabled' if enabled else 'disabled'}")
return True
except Exception as e:
logger.error(f"❌ Failed to persist provider state: {e}")
return False
return True
def get_enabled_provider_ids(self) -> list:
"""Get list of enabled provider IDs."""
return [
provider_id
for provider_id, config in self._providers.items()
if config.get("enabled", False)
]
# Global instance
_provider_state_manager: Optional[ProviderStateManager] = None
async def get_provider_state_manager() -> ProviderStateManager:
"""Get the global provider state manager."""
global _provider_state_manager
if _provider_state_manager is None:
_provider_state_manager = ProviderStateManager()
await _provider_state_manager.initialize()
return _provider_state_manager
def get_provider_state_manager_sync() -> ProviderStateManager:
"""Get provider state manager without async (for sync contexts)."""
global _provider_state_manager
if _provider_state_manager is None:
_provider_state_manager = ProviderStateManager()
# Initialize with defaults synchronously
_provider_state_manager._providers = PROVIDERS.copy()
_provider_state_manager._initialized = True
return _provider_state_manager