# safetyAI / app.py
# (Hugging Face Space page header from the original export, commented out so
#  the module parses: al1kss's picture / "Update app.py" / commit 04934a2
#  verified / raw / history blame / 58.3 kB)
import asyncio
import hashlib
import json
import logging
import os
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional

import bcrypt
import jwt
import PyPDF2
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, EmailStr

# Import the corrected persistent LightRAG manager
from lightrag_manager import (
    PersistentLightRAGManager,
    CloudflareWorker,
    initialize_lightrag_manager,
    get_lightrag_manager,
    validate_environment,
    EnvironmentError,
    RAGConfig,
)
# Configure module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables.
# NOTE(review): when JWT_SECRET_KEY is unset, a fresh random secret is
# generated on every process start, invalidating all outstanding tokens on
# restart and breaking multi-instance deployments — confirm a fixed secret
# is configured in production.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24 * 7  # 1 week

# File upload settings
MAX_UPLOAD_SIZE = int(os.getenv("MAX_UPLOAD_SIZE", "10485760"))  # 10MB

# Create FastAPI app
app = FastAPI(
    title="YourAI - Complete LightRAG Backend",
    version="4.0.0",
    description="Complete LightRAG system with Vercel-only persistence and JWT authentication"
)

# Add CORS middleware.
# NOTE(review): browsers reject Access-Control-Allow-Origin "*" combined with
# allow_credentials=True — tighten allow_origins if credentialed CORS is needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bearer-token security scheme consumed by get_current_user.
security = HTTPBearer()
# Pydantic models
class UserRegisterRequest(BaseModel):
    """Payload for /auth/register."""
    email: EmailStr
    name: str
    password: str

class UserLoginRequest(BaseModel):
    """Payload for /auth/login."""
    email: EmailStr
    password: str

class QuestionRequest(BaseModel):
    """A chat question for any of the /chat endpoints."""
    question: str
    mode: Optional[str] = "hybrid"  # LightRAG query mode: naive | local | global | hybrid
    conversation_id: Optional[str] = None  # omit to start a new conversation

class QuestionResponse(BaseModel):
    """Answer returned by the chat endpoints."""
    answer: str
    mode: str  # the mode that actually produced the answer (may be a fallback)
    status: str
    conversation_id: Optional[str] = None

class CustomAIRequest(BaseModel):
    """Metadata for creating a custom AI via /create-custom-ai."""
    name: str
    description: str

class FileUploadResponse(BaseModel):
    """Per-file result entry returned by /upload-files."""
    filename: str
    size: int
    message: str

class UserResponse(BaseModel):
    """Public user profile embedded in auth responses."""
    id: str
    email: str
    name: str
    created_at: str

class AuthResponse(BaseModel):
    """JWT plus profile returned by register/login."""
    token: str
    user: UserResponse
    message: str

# Global variables — singleton manager, populated by the startup event.
lightrag_manager: Optional[PersistentLightRAGManager] = None
# JWT utilities
def create_jwt_token(user_data: dict) -> str:
    """Create a signed JWT access token for a user.

    Args:
        user_data: Mapping with at least "id", "email" and "name" keys.

    Returns:
        Encoded HS256 JWT string valid for JWT_EXPIRATION_HOURS.
    """
    # Timezone-aware "now": datetime.utcnow() is deprecated and naive;
    # PyJWT accepts aware datetimes for exp/iat directly.
    now = datetime.now(timezone.utc)
    payload = {
        "user_id": user_data["id"],
        "email": user_data["email"],
        "name": user_data["name"],
        "exp": now + timedelta(hours=JWT_EXPIRATION_HOURS),
        "iat": now,
        "type": "access"
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def verify_jwt_token(token: str) -> dict:
    """Decode *token* and return its claims, raising HTTP 401 on failure."""
    try:
        return jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Must precede InvalidTokenError: ExpiredSignatureError subclasses it.
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
def hash_email(email: str) -> str:
    """Return the hex MD5 digest of *email*, used as a stable user identifier.

    NOTE(review): MD5 is acceptable as a non-cryptographic ID here, but must
    not be relied on for anything security-sensitive.
    """
    digest = hashlib.md5(email.encode())
    return digest.hexdigest()
# Authentication dependency
async def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> dict:
    """FastAPI dependency: resolve the bearer token into a user dict.

    Raises HTTP 401 (via verify_jwt_token) when the token is expired/invalid.
    """
    claims = verify_jwt_token(credentials.credentials)
    return {
        "id": claims["user_id"],
        "email": claims["email"],
        "name": claims["name"],
    }
# Optional authentication (for endpoints that work with or without auth)
async def get_current_user_optional(authorization: Optional[str] = None) -> Optional[dict]:
    """Return the authenticated user dict, or None when absent/invalid.

    Args:
        authorization: Raw "Bearer <token>" header value (or None).

    Returns:
        User dict with id/email/name, or None for anonymous/invalid tokens.
    """
    if not authorization:
        return None
    try:
        # FIX: removeprefix only strips a leading "Bearer ", whereas the
        # original str.replace would also delete that substring anywhere
        # inside the token, corrupting it.
        token = authorization.removeprefix("Bearer ")
        payload = verify_jwt_token(token)
        return {
            "id": payload["user_id"],
            "email": payload["email"],
            "name": payload["name"]
        }
    except Exception:
        # FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Any verification failure means anonymous access.
        return None
# Replace the startup_event function in app.py
@app.on_event("startup")
async def startup_event():
    """Boot the LightRAG backend: validate env, connect storage, warm caches.

    Sequence:
      1. Validate required environment variables (raises EnvironmentError).
      2. Initialize the persistent LightRAG manager and probe the DB pool.
      3. Deduplicate fire-safety RAG rows and seed the system stats table.
      4. Load the persisted fire-safety RAG, or rebuild it from documents
         (3-5 minutes), and cache it in lightrag_manager.rag_instances.

    Any failure re-raises so the app refuses to start half-initialized.
    """
    global lightrag_manager
    logger.info("🚀 Starting up YourAI with Complete LightRAG Backend...")
    try:
        # Validate environment variables
        validate_environment()
        # Initialize LightRAG manager
        lightrag_manager = await initialize_lightrag_manager()
        # Test database connection with a trivial round-trip query.
        async with lightrag_manager.db.pool.acquire() as conn:
            result = await conn.fetchval("SELECT 1")
            logger.info(f"✅ Database connection successful: {result}")
        # Clean up duplicate fire safety RAG instances first
        logger.info("🧹 Cleaning up duplicate RAG instances...")
        active_instance = await lightrag_manager.db.cleanup_duplicate_rag_instances("fire-safety", keep_latest=True)
        if lightrag_manager and lightrag_manager.db:
            await lightrag_manager.db.initialize_system_stats()
            logger.info("✅ System stats initialized")
        if active_instance:
            logger.info(f"🎯 Using active fire safety RAG: {active_instance['name']} (ID: {active_instance['id']})")
        # Shared config for the singleton fire-safety RAG (no owner, no ai_id).
        config = RAGConfig(
            ai_type="fire-safety",
            user_id=None,
            ai_id=None,
            name="Fire Safety Expert",
            description="Expert in fire safety regulations and procedures"
        )
        # Try to load existing RAG from database with DIRECT QUERY TEST
        logger.info("🔍 Attempting to load existing fire safety RAG from DATABASE...")
        try:
            rag_instance = await lightrag_manager._load_from_database(config)
            if rag_instance:
                # Cache the loaded instance so request handlers reuse it.
                lightrag_manager.rag_instances[config.get_cache_key()] = rag_instance
                logger.info("🎉 SUCCESS: Fire safety RAG loaded from database and QUERY TESTED")
                logger.info("✅ NO RECREATION NEEDED - RAG is working!")
            else:
                # Load returned nothing usable: rebuild from source documents.
                logger.info("🔧 Database loading/testing failed - creating new RAG instance")
                logger.info("⏳ This will take 3-5 minutes to process all documents...")
                # Create new RAG instance
                rag_instance = await lightrag_manager._create_new_rag_instance(config)
                # Save to database
                await lightrag_manager._save_to_database(config, rag_instance)
                # Cache the new instance
                lightrag_manager.rag_instances[config.get_cache_key()] = rag_instance
                logger.info("✅ Created and saved new fire safety RAG instance")
        except Exception as e:
            # Loading itself blew up: rebuild as a last resort.
            logger.error(f"❌ Error during fire safety RAG initialization: {e}")
            logger.info("🔧 Creating NEW fire safety RAG (fallback due to error)")
            logger.info("⏳ This will take 3-5 minutes to process all documents...")
            rag_instance = await lightrag_manager._create_new_rag_instance(config)
            await lightrag_manager._save_to_database(config, rag_instance)
            lightrag_manager.rag_instances[config.get_cache_key()] = rag_instance
        logger.info("✅ Complete LightRAG system initialized successfully")
        logger.info(f"🎯 Final cached RAG instances: {list(lightrag_manager.rag_instances.keys())}")
    except EnvironmentError as e:
        logger.error(f"❌ Environment validation failed: {e}")
        raise
    except Exception as e:
        logger.error(f"❌ Startup failed: {e}")
        raise
# Shutdown event
@app.on_event("shutdown")
async def shutdown_event():
    """Release LightRAG resources (DB pool, Redis, caches) on application exit."""
    global lightrag_manager
    manager = lightrag_manager
    if manager:
        await manager.cleanup()
    logger.info("Shutdown complete")
# Root endpoint
@app.get("/")
async def root():
    """API landing payload: identity, version, status and feature flags."""
    feature_flags = dict.fromkeys(
        (
            "persistent_lightrag",
            "vercel_only_storage",
            "jwt_authentication",
            "conversation_memory",
            "custom_ai_support",
            "zero_token_waste",
        ),
        True,
    )
    return {
        "message": "YourAI - Complete LightRAG Backend API",
        "version": "4.0.0",
        "status": "running",
        "features": feature_flags,
    }
# Stats response model
class StatsResponse(BaseModel):
    """Aggregate usage statistics returned by /api/stats."""
    total_users: int
    total_ais: int
    total_messages: int
    lines_of_code_generated: int  # This will be total characters, not actual lines
    last_updated: str

class StatsUpdateRequest(BaseModel):
    """Increment request for /api/stats/update."""
    stat_type: str  # 'users', 'ais', 'messages', 'characters'
    increment: int = 1
# Get current stats
@app.get("/api/stats", response_model=StatsResponse)
async def get_stats():
    """Return aggregate usage statistics.

    "lines_of_code_generated" is really the summed character count over the
    messages tables. If no stats row exists yet, a zeroed row is inserted
    and zeros are returned immediately.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        async with lightrag_manager.db.pool.acquire() as conn:
            # Get the latest stats entry
            stats_row = await conn.fetchrow("""
                SELECT total_users, total_ais, total_messages, date
                FROM system_stats
                ORDER BY date DESC
                LIMIT 1
            """)
            if not stats_row:
                # First run: seed an all-zero stats row and short-circuit.
                await conn.execute("""
                    INSERT INTO system_stats (total_users, total_ais, total_messages, date)
                    VALUES (0, 0, 0, NOW())
                """)
                return StatsResponse(
                    total_users=0,
                    total_ais=0,
                    total_messages=0,
                    lines_of_code_generated=0,
                    last_updated=datetime.now().isoformat()
                )
            # Calculate total characters from all messages (our "lines of code")
            total_characters = await conn.fetchval("""
                SELECT COALESCE(SUM(LENGTH(content)), 0)
                FROM messages
            """)
            # NOTE(review): the `or 0` below only covers a NULL result; if the
            # conversation_messages table does not exist this query raises and
            # the endpoint returns 500 — confirm the table always exists.
            conversation_chars = await conn.fetchval("""
                SELECT COALESCE(SUM(LENGTH(content)), 0)
                FROM conversation_messages
            """) or 0
            total_lines_of_code = total_characters + conversation_chars
            return StatsResponse(
                total_users=stats_row['total_users'],
                total_ais=stats_row['total_ais'],
                total_messages=stats_row['total_messages'],
                lines_of_code_generated=total_lines_of_code,
                last_updated=stats_row['date'].isoformat()
            )
    except Exception as e:
        logger.error(f"Failed to get stats: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve stats")
# Update stats (for real-time increments)
@app.post("/api/stats/update")
async def update_stats(request: StatsUpdateRequest):
    """Increment one of today's aggregate counters.

    Handles stat_type values 'users', 'ais' and 'messages'. Note that
    'characters' (advertised by StatsUpdateRequest) has no branch here, so
    such requests change nothing yet still report success — TODO confirm
    whether that is intended.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        async with lightrag_manager.db.pool.acquire() as conn:
            today = datetime.now().date()
            # Get or create today's stats entry
            stats_row = await conn.fetchrow("""
                SELECT total_users, total_ais, total_messages
                FROM system_stats
                WHERE DATE(date) = $1
            """, today)
            if not stats_row:
                # Create today's entry based on the most recent historical row.
                yesterday_stats = await conn.fetchrow("""
                    SELECT total_users, total_ais, total_messages
                    FROM system_stats
                    ORDER BY date DESC
                    LIMIT 1
                """)
                base_users = yesterday_stats['total_users'] if yesterday_stats else 0
                base_ais = yesterday_stats['total_ais'] if yesterday_stats else 0
                base_messages = yesterday_stats['total_messages'] if yesterday_stats else 0
                await conn.execute("""
                    INSERT INTO system_stats (total_users, total_ais, total_messages, date)
                    VALUES ($1, $2, $3, NOW())
                """, base_users, base_ais, base_messages)
                current_users = base_users
                current_ais = base_ais
                current_messages = base_messages
            else:
                current_users = stats_row['total_users']
                current_ais = stats_row['total_ais']
                current_messages = stats_row['total_messages']
            # Update the specific stat.
            # NOTE(review): read-then-write with no transaction/row lock —
            # concurrent calls can lose increments; confirm acceptable.
            if request.stat_type == 'users':
                new_users = current_users + request.increment
                await conn.execute("""
                    UPDATE system_stats
                    SET total_users = $1, date = NOW()
                    WHERE DATE(date) = $2
                """, new_users, today)
            elif request.stat_type == 'ais':
                new_ais = current_ais + request.increment
                await conn.execute("""
                    UPDATE system_stats
                    SET total_ais = $1, date = NOW()
                    WHERE DATE(date) = $2
                """, new_ais, today)
            elif request.stat_type == 'messages':
                new_messages = current_messages + request.increment
                await conn.execute("""
                    UPDATE system_stats
                    SET total_messages = $1, date = NOW()
                    WHERE DATE(date) = $2
                """, new_messages, today)
        return {"message": f"Updated {request.stat_type} by {request.increment}", "status": "success"}
    except Exception as e:
        logger.error(f"Failed to update stats: {e}")
        raise HTTPException(status_code=500, detail="Failed to update stats")
# Helper function to increment stats (call this from other endpoints)
async def increment_stat(stat_type: str, increment: int = 1):
    """Best-effort upsert-style increment of a system stat.

    Unlike /api/stats/update, failures are logged and swallowed so calling
    endpoints never fail because of stats bookkeeping.

    NOTE(review): the INSERT uses date = NOW() (a timestamp) with
    ON CONFLICT (date); unless the `date` column is a DATE (or has a
    date-truncated unique constraint), two calls will rarely collide and
    each call inserts a new row instead of incrementing — confirm schema.
    """
    try:
        if not lightrag_manager:
            return
        async with lightrag_manager.db.pool.acquire() as conn:
            today = datetime.now().date()
            # Upsert today's stats
            if stat_type == 'users':
                await conn.execute("""
                    INSERT INTO system_stats (total_users, total_ais, total_messages, date)
                    VALUES ($1, 0, 0, NOW())
                    ON CONFLICT (date) DO UPDATE SET
                        total_users = system_stats.total_users + $1,
                        date = NOW()
                """, increment)
            elif stat_type == 'ais':
                await conn.execute("""
                    INSERT INTO system_stats (total_users, total_ais, total_messages, date)
                    VALUES (0, $1, 0, NOW())
                    ON CONFLICT (date) DO UPDATE SET
                        total_ais = system_stats.total_ais + $1,
                        date = NOW()
                """, increment)
            elif stat_type == 'messages':
                await conn.execute("""
                    INSERT INTO system_stats (total_users, total_ais, total_messages, date)
                    VALUES (0, 0, $1, NOW())
                    ON CONFLICT (date) DO UPDATE SET
                        total_messages = system_stats.total_messages + $1,
                        date = NOW()
                """, increment)
    except Exception as e:
        logger.error(f"Failed to increment {stat_type}: {e}")
# Health check
@app.get("/health")
async def health_check():
    """Report liveness plus the state of the core backing connections."""
    manager = lightrag_manager
    return {
        "status": "healthy",
        "lightrag_initialized": manager is not None,
        "vercel_storage": True,
        "database_connected": manager.db.pool is not None if manager else False,
        "redis_connected": manager.db.redis is not None if manager else False,
        "environment_validated": True,
    }
@app.post("/auth/register", response_model=AuthResponse)
async def register_user(request: UserRegisterRequest):
    """Register a new user and return a JWT plus the user profile.

    Raises:
        HTTPException 400: password too short or email already registered.
        HTTPException 503: backend not initialized.
        HTTPException 500: unexpected database/JWT failure.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Validate password (basic validation)
        if len(request.password) < 8:
            raise HTTPException(status_code=400, detail="Password must be at least 8 characters long")
        # Check if user already exists
        async with lightrag_manager.db.pool.acquire() as conn:
            existing_user = await conn.fetchrow("""
                SELECT id FROM users WHERE email = $1
            """, request.email)
            if existing_user:
                raise HTTPException(status_code=400, detail="User already exists")
            # Hash password with bcrypt (salted, one-way).
            password_hash = bcrypt.hashpw(request.password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
            # Create user in database.
            # SECURITY FIX: the original stored request.password in plaintext
            # in the `password` column. We now store the bcrypt hash there as
            # well, satisfying any NOT NULL constraint without keeping
            # cleartext credentials at rest.
            user_id = str(uuid.uuid4())
            await conn.execute("""
                INSERT INTO users (id, email, name, password, password_hash, hashed_email, created_at, updated_at)
                VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
            """, user_id, request.email, request.name, password_hash, password_hash,
                hash_email(request.email))  # reuse helper instead of inline md5
        user = {
            "id": user_id,
            "email": request.email,
            "name": request.name,
            "created_at": datetime.now().isoformat()
        }
        # Create JWT token
        token = create_jwt_token(user)
        await lightrag_manager.db.update_system_stat('users', 1)
        logger.info(f"📈 User registered and stats updated: {request.email}")
        return AuthResponse(
            message="User created successfully",
            user=UserResponse(**user),
            token=token
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Registration failed: {e}")
        raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
@app.post("/auth/login", response_model=AuthResponse)
async def login_user(request: UserLoginRequest):
    """Authenticate a user by email/password and return a fresh JWT.

    Raises:
        HTTPException 401: unknown/inactive email or wrong password (same
            message for both to avoid account enumeration).
        HTTPException 503/500: backend not ready / unexpected failure.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Get user from database (inactive accounts are treated as unknown).
        async with lightrag_manager.db.pool.acquire() as conn:
            user_record = await conn.fetchrow("""
                SELECT id, email, name, password_hash, created_at
                FROM users
                WHERE email = $1 AND is_active = true
            """, request.email)
            if not user_record:
                raise HTTPException(status_code=401, detail="Invalid email or password")
            # NOTE(review): pydantic already requires `password` on
            # UserLoginRequest, so this check only rejects empty strings.
            if not hasattr(request, 'password') or not request.password:
                raise HTTPException(status_code=400, detail="Password is required")
            # Verify password against the stored bcrypt hash.
            if not bcrypt.checkpw(request.password.encode('utf-8'), user_record['password_hash'].encode('utf-8')):
                raise HTTPException(status_code=401, detail="Invalid email or password")
            user = {
                "id": str(user_record['id']),
                "email": user_record['email'],
                "name": user_record['name'],
                "created_at": user_record['created_at'].isoformat()
            }
        # Create JWT token
        token = create_jwt_token(user)
        logger.info(f"User logged in: {request.email}")
        return AuthResponse(
            token=token,
            user=UserResponse(**user),
            message="Login successful"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Login failed: {e}")
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
@app.get("/debug/rag-modes")
async def debug_rag_modes():
    """Exercise every LightRAG query mode against a fixed question.

    Returns a per-mode status/preview map plus the list of modes that
    produced a usable answer. Intended for manual diagnostics only.
    """
    if not lightrag_manager:
        return {"error": "LightRAG manager not initialized"}
    test_question = "What are fire exit requirements?"
    modes = ["naive", "local", "global", "hybrid"]
    results = {}
    # PERF FIX: the instance lookup and QueryParam import are invariant
    # across modes — the original re-fetched both on every loop iteration.
    from lightrag import QueryParam
    try:
        rag_instance = await lightrag_manager.get_or_create_rag_instance("fire-safety")
    except Exception as e:
        # Preserve the original shape: record the failure for every mode.
        return {
            "test_question": test_question,
            "mode_results": {m: {"status": "error", "error": str(e)} for m in modes},
            "working_modes": []
        }
    for mode in modes:
        try:
            response = await rag_instance.aquery(test_question, QueryParam(mode=mode))
            results[mode] = {
                "status": "success" if response and not response.startswith("Sorry") else "failed",
                "response_length": len(response) if response else 0,
                "response_preview": response[:200] if response else "No response"
            }
        except Exception as e:
            results[mode] = {
                "status": "error",
                "error": str(e)
            }
    return {
        "test_question": test_question,
        "mode_results": results,
        "working_modes": [mode for mode, result in results.items() if result["status"] == "success"]
    }
# Fire Safety Chat
@app.post("/chat/fire-safety", response_model=QuestionResponse)
async def chat_fire_safety(
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Answer a fire-safety question via the shared fire-safety RAG.

    NOTE(review): `query_rag_with_fallback` is not defined in this file's
    visible portion (only query_custom_ai_with_fallback is) — confirm it is
    defined elsewhere in the module/project.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Generate conversation ID if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
        # Query with mode fallback, starting from the caller's preference.
        result = await query_rag_with_fallback(
            lightrag_manager=lightrag_manager,
            ai_type="fire-safety",
            question=request.question,
            conversation_id=conversation_id,
            user_id=current_user["id"],
            preferred_mode=request.mode or "hybrid"
        )
        # Count this exchange (user + assistant) in system stats.
        await update_message_stats(request.question, result["answer"])
        return QuestionResponse(
            answer=result["answer"],
            mode=result["mode"],
            status=result["status"],
            conversation_id=conversation_id
        )
    except Exception as e:
        logger.error(f"Fire safety chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
# General Chat
@app.post("/chat/general", response_model=QuestionResponse)
async def chat_general(
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """General AI chat: bypasses RAG and calls the Cloudflare worker directly."""
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Generate conversation ID if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
        # Use Cloudflare worker directly for general chat
        system_prompt = """You are a helpful general AI assistant. Provide accurate, helpful, and engaging responses. If they ask what AI model are you, answer you were trained and dont know the exact model name."""
        response = await lightrag_manager.cloudflare_worker.query(
            request.question,
            system_prompt
        )
        # Persist both sides of the exchange for conversation memory.
        await lightrag_manager.db.save_conversation_message(
            conversation_id, "user", request.question, {
                "user_id": current_user["id"],
                "ai_type": "general"
            }
        )
        await lightrag_manager.db.save_conversation_message(
            conversation_id, "assistant", response, {
                "mode": request.mode or "direct",
                "ai_type": "general",
                "user_id": current_user["id"]
            }
        )
        await update_message_stats(request.question, response)
        return QuestionResponse(
            answer=response,
            mode=request.mode or "direct",
            status="success",
            conversation_id=conversation_id
        )
    except Exception as e:
        logger.error(f"General chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
# Custom AI Chat
# Custom AI Chat - FIXED VERSION with Fallback
# NOTE(review): this exact path ("/chat/custom/{ai_id}") is registered a
# second time further down (chat_custom_ai_enhanced). Starlette matches the
# first-registered route, so this handler appears to serve all requests and
# the later one is effectively dead — confirm and remove one of them.
@app.post("/chat/custom/{ai_id}", response_model=QuestionResponse)
async def chat_custom_ai(
    ai_id: str,
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Chat with a user-created custom AI, with query-mode fallback."""
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Generate conversation ID if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
        # Query with fallback - SAME AS FIRE SAFETY
        result = await query_rag_with_fallback(
            lightrag_manager=lightrag_manager,
            ai_type="custom",
            question=request.question,
            conversation_id=conversation_id,
            user_id=current_user["id"],
            ai_id=ai_id,  # This is the key difference - pass ai_id
            preferred_mode=request.mode or "hybrid"
        )
        await update_message_stats(request.question, result["answer"])
        return QuestionResponse(
            answer=result["answer"],
            mode=result["mode"],
            status=result["status"],
            conversation_id=conversation_id
        )
    except Exception as e:
        logger.error(f"Custom AI chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
# Alternative: Enhanced fallback system specifically for custom AIs
async def query_custom_ai_with_fallback(
    lightrag_manager: PersistentLightRAGManager,
    ai_id: str,
    question: str,
    conversation_id: str,
    user_id: str,
    preferred_mode: str = "hybrid"
) -> Dict[str, Any]:
    """Query a custom AI, degrading gracefully through fallback layers.

    Order of attempts:
      1. LightRAG query modes — the caller's preferred mode first, then the
         remaining modes in ("hybrid", "local", "global", "naive") order.
      2. A direct Cloudflare AI call with a prompt built from the AI's stored
         name/description (no knowledge base).
      3. A static apology string carrying the last recorded error.

    Returns:
        Dict with keys: answer, mode, status, fallback_used.
    """
    # Try modes in order of preference
    fallback_modes = ["hybrid", "local", "global", "naive"]
    # Move the caller's preferred mode to the front of the list.
    if preferred_mode in fallback_modes:
        fallback_modes.remove(preferred_mode)
        fallback_modes.insert(0, preferred_mode)
    last_error = None
    for mode in fallback_modes:
        try:
            logger.info(f"🔍 Trying {mode} mode for custom AI {ai_id}")
            response = await lightrag_manager.query_with_memory(
                ai_type="custom",
                question=question,
                conversation_id=conversation_id,
                user_id=user_id,
                ai_id=ai_id,
                mode=mode
            )
            # Heuristic validity check: non-trivial length and not an apology.
            if response and len(response.strip()) > 20 and not response.startswith("Sorry"):
                logger.info(f"✅ {mode} mode worked for custom AI {ai_id}")
                return {
                    "answer": response,
                    "mode": mode,
                    "status": "success",
                    "fallback_used": mode != preferred_mode
                }
            else:
                logger.warning(f"⚠️ {mode} mode returned empty/error response for custom AI {ai_id}")
                last_error = f"{mode} mode returned: {response[:100]}..."
        except Exception as e:
            logger.warning(f"⚠️ {mode} mode failed for custom AI {ai_id}: {e}")
            last_error = str(e)
            continue
    # If all LightRAG modes fail, fallback to Cloudflare AI
    try:
        logger.info(f"🔄 All LightRAG modes failed for custom AI {ai_id}, falling back to Cloudflare AI")
        # Get custom AI details for a context-aware prompt.
        async with lightrag_manager.db.pool.acquire() as conn:
            ai_details = await conn.fetchrow("""
                SELECT name, description FROM rag_instances
                WHERE ai_id = $1 AND ai_type = 'custom' AND status = 'active'
            """, ai_id)
        # Create context-aware system prompt (defaults when no row found).
        ai_name = ai_details['name'] if ai_details else "Custom AI"
        ai_description = ai_details['description'] if ai_details else "a custom AI assistant"
        system_prompt = f"""You are {ai_name}, {ai_description}.
Although you don't have access to your specific knowledge base right now,
provide the best general assistance you can. Be helpful, accurate, and
acknowledge that you're operating in a general mode without your specialized knowledge."""
        fallback_response = await lightrag_manager.cloudflare_worker.query(
            question,
            system_prompt
        )
        # Persist both sides of the exchange, tagging the fallback mode and
        # the error that forced the fallback.
        await lightrag_manager.db.save_conversation_message(
            conversation_id, "user", question, {
                "user_id": user_id,
                "ai_type": "custom",
                "ai_id": ai_id
            }
        )
        await lightrag_manager.db.save_conversation_message(
            conversation_id, "assistant", fallback_response, {
                "mode": "cloudflare_fallback",
                "ai_type": "custom",
                "ai_id": ai_id,
                "user_id": user_id,
                "error": last_error
            }
        )
        return {
            "answer": fallback_response,
            "mode": "cloudflare_fallback",
            "status": "success",
            "fallback_used": True
        }
    except Exception as fallback_error:
        logger.error(f"❌ Even Cloudflare fallback failed for custom AI {ai_id}: {fallback_error}")
        # Last resort - return informative error
        error_response = f"I'm having trouble accessing my knowledge base right now. Please try again in a moment, or contact support if the issue persists. (Error: {last_error})"
        return {
            "answer": error_response,
            "mode": "error",
            "status": "error",
            "fallback_used": True
        }
@app.post("/admin/rebuild-custom-ai/{ai_id}")
async def rebuild_custom_ai(
    ai_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Force a rebuild of a (possibly corrupted) custom AI's RAG instance."""
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        rebuilt = await lightrag_manager.force_rebuild_custom_ai(ai_id, current_user["id"])
        # The failure HTTPException raised here is intentionally caught by the
        # handler below and re-wrapped as "Rebuild error: ..." (matches the
        # original control flow).
        if not rebuilt:
            raise HTTPException(status_code=500, detail="Failed to rebuild custom AI")
        return {"message": f"Successfully rebuilt custom AI {ai_id}"}
    except Exception as e:
        logger.error(f"Rebuild custom AI error: {e}")
        raise HTTPException(status_code=500, detail=f"Rebuild error: {str(e)}")
# NOTE(review): "/chat/custom/{ai_id}" is already registered above by
# chat_custom_ai. Starlette resolves to the first-registered route, so this
# handler (which uses the enhanced query_custom_ai_with_fallback) appears to
# be unreachable — confirm and consolidate the two handlers.
@app.post("/chat/custom/{ai_id}", response_model=QuestionResponse)
async def chat_custom_ai_enhanced(
    ai_id: str,
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Chat with a custom AI using the enhanced (Cloudflare-backed) fallback."""
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Generate conversation ID if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
        # Query with enhanced custom AI fallback
        result = await query_custom_ai_with_fallback(
            lightrag_manager=lightrag_manager,
            ai_id=ai_id,
            question=request.question,
            conversation_id=conversation_id,
            user_id=current_user["id"],
            preferred_mode=request.mode or "hybrid"
        )
        await update_message_stats(request.question, result["answer"])
        return QuestionResponse(
            answer=result["answer"],
            mode=result["mode"],
            status=result["status"],
            conversation_id=conversation_id
        )
    except Exception as e:
        logger.error(f"Custom AI chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
@app.post("/admin/stats/recalculate")
async def recalculate_stats(current_user: dict = Depends(get_current_user)):
    """Recompute aggregate stats from live table counts.

    NOTE(review): no admin-role check is enforced — any authenticated user
    can trigger a recalculation; confirm whether that is acceptable.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        await lightrag_manager.db.initialize_system_stats()
    except Exception as e:
        logger.error(f"Failed to recalculate stats: {e}")
        raise HTTPException(status_code=500, detail="Failed to recalculate stats")
    return {"message": "Stats recalculated successfully", "status": "success"}
# File upload endpoint
@app.post("/upload-files", response_model=List[FileUploadResponse])
async def upload_files(
    files: List[UploadFile] = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Upload and pre-process knowledge files for a custom AI.

    Text-like files are decoded as UTF-8 (errors ignored); PDFs are extracted
    with the same limits the fire-safety loader uses (first 20 pages, skip
    near-empty pages, 3000 chars per page). Processed text is parked in
    `upload_files.temp_storage` keyed by user id until /create-custom-ai
    consumes it.

    Raises:
        HTTPException 400: disallowed extension or missing filename.
        HTTPException 413: file exceeds MAX_UPLOAD_SIZE.
    """
    uploaded_files = []
    allowed_extensions = {'.txt', '.md', '.json', '.pdf', '.docx'}
    for file in files:
        # FIX: UploadFile.size can be None (no Content-Length from client);
        # the original unconditional `file.size > MAX_UPLOAD_SIZE` would then
        # raise TypeError. We also re-check the actual byte count below.
        if file.size is not None and file.size > MAX_UPLOAD_SIZE:
            raise HTTPException(
                status_code=413,
                detail=f"File {file.filename} too large. Max size: {MAX_UPLOAD_SIZE} bytes"
            )
        # FIX: filename can be None/empty; Path(None) would raise TypeError.
        if not file.filename:
            raise HTTPException(status_code=400, detail="Uploaded file is missing a filename")
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"File type {file_extension} not allowed. Allowed: {allowed_extensions}"
            )
        # Read file content
        content = await file.read()
        # Enforce the cap on the bytes actually read (covers size=None case).
        if len(content) > MAX_UPLOAD_SIZE:
            raise HTTPException(
                status_code=413,
                detail=f"File {file.filename} too large. Max size: {MAX_UPLOAD_SIZE} bytes"
            )
        # Process content based on file type (same as fire-safety)
        if file_extension == '.pdf':
            try:
                # Same PDF processing as lightrag_manager._load_fire_safety_knowledge()
                pdf_reader = PyPDF2.PdfReader(BytesIO(content))
                text_content = ""
                for page_num in range(min(20, len(pdf_reader.pages))):  # same page limit
                    page_text = pdf_reader.pages[page_num].extract_text()
                    if page_text and len(page_text.strip()) > 100:
                        text_content += f"Page {page_num + 1}: {page_text[:3000]}\n"  # same chunking
                logger.info(f"Processed PDF {file.filename}: {len(text_content)} characters")
            except Exception as e:
                # Fall back to a best-effort text decode of the raw bytes.
                logger.warning(f"PDF processing failed for {file.filename}: {e}")
                text_content = content.decode('utf-8', errors='ignore')
        else:
            text_content = content.decode('utf-8', errors='ignore')
        # Function-attribute staging area (process-local, lost on restart).
        # NOTE(review): not shared across workers; create_custom_ai depends
        # on this exact attribute name — keep in sync if refactoring.
        if not hasattr(upload_files, 'temp_storage'):
            upload_files.temp_storage = {}
        if current_user["id"] not in upload_files.temp_storage:
            upload_files.temp_storage[current_user["id"]] = []
        upload_files.temp_storage[current_user["id"]].append({
            "filename": file.filename,
            "content": text_content,
            "type": file_extension.lstrip('.'),
            "size": len(content),
            "processed_chars": len(text_content)
        })
        uploaded_files.append(FileUploadResponse(
            filename=file.filename,
            size=len(content),
            message=f"Processed successfully ({len(text_content)} characters)"
        ))
    return uploaded_files
# Create custom AI
@app.post("/create-custom-ai")
async def create_custom_ai(
    ai_data: CustomAIRequest,
    current_user: dict = Depends(get_current_user)
):
    """Build a persistent custom AI from the user's staged uploads.

    Consumes the files previously parked by /upload-files in
    `upload_files.temp_storage`, chunks them exactly like the fire-safety
    loader (3000-char sections, 60000-char cap per file), inserts each chunk
    into a fresh RAG instance with a 45s per-chunk timeout, then persists
    the RAG instance to the database.

    Raises:
        HTTPException 400: no staged files for this user.
        HTTPException 503/500: backend not ready / build failure.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    # Check for files staged by /upload-files.
    if not hasattr(upload_files, 'temp_storage') or current_user["id"] not in upload_files.temp_storage:
        raise HTTPException(status_code=400, detail="No files uploaded. Please upload knowledge files first.")
    uploaded_files = upload_files.temp_storage[current_user["id"]]
    if not uploaded_files:
        raise HTTPException(status_code=400, detail="No files found. Please upload knowledge files first.")
    try:
        # Generate AI ID
        ai_id = str(uuid.uuid4())
        logger.info(f"🔥 Creating custom AI: {ai_data.name} for user {current_user['id']} with ID {ai_id}")
        # Use EXACT same creation path as fire-safety.
        rag_instance = await lightrag_manager.get_or_create_rag_instance(
            ai_type="custom",
            user_id=current_user["id"],
            ai_id=ai_id,
            name=ai_data.name,
            description=ai_data.description
        )
        logger.info(f"✅ RAG instance created for custom AI {ai_id}")
        # Chunk files exactly like fire-safety does.
        all_content = []
        for file_data in uploaded_files:
            content = file_data["content"]
            filename = file_data["filename"]
            # FIX: label chunks and logs with the real filename — the original
            # assigned `filename` but then used the literal string "(unknown)",
            # so the inserted RAG content lost its source attribution.
            logger.info(f"📄 Processing {filename}: {len(content)} chars")
            if len(content) > 3000:
                # Split into 3000-char chunks, at most 60000 chars per file
                # (same limits as fire-safety).
                for i in range(0, min(len(content), 60000), 3000):
                    chunk = content[i:i + 3000]
                    if chunk.strip():
                        all_content.append(f"{filename} Section {i // 3000 + 1}: {chunk}")
                        logger.info(f"📝 Created chunk {i // 3000 + 1} for {filename}: {len(chunk)} chars")
            else:
                all_content.append(f"{filename}: {content}")
        logger.info(f"📚 Starting insertion of {len(all_content)} chunks (same as fire-safety)")
        successful_insertions = 0
        # Insert chunks one by one with a per-chunk timeout (same as fire-safety).
        for i, content_chunk in enumerate(all_content):
            try:
                logger.info(f"📝 Inserting chunk {i + 1}/{len(all_content)}: ({len(content_chunk)} chars)")
                insertion_task = asyncio.create_task(rag_instance.ainsert(content_chunk))
                try:
                    await asyncio.wait_for(insertion_task, timeout=45.0)
                    successful_insertions += 1
                    logger.info(f"✅ Chunk {i + 1} inserted successfully")
                    # Brief pause between inserts (matches fire-safety pacing).
                    await asyncio.sleep(2)
                except asyncio.TimeoutError:
                    logger.error(f"⏰ Chunk {i + 1} insertion timed out after 45 seconds")
                    insertion_task.cancel()
                    continue
            except Exception as e:
                logger.error(f"❌ Failed to insert chunk {i + 1}: {e}")
                continue
        # FIX: the original logged "<chunk count>/<file count> files" — the
        # counter tracks chunks, so report it against the chunk total.
        logger.info(f"✅ Successfully inserted {successful_insertions}/{len(all_content)} chunks")
        # Wait for background processing and sanity-check the stores.
        if successful_insertions > 0:
            logger.info("🔍 Final validation and cleaning...")
            await asyncio.sleep(5)  # give background processing time to settle
            entities_count = await lightrag_manager._count_entities(rag_instance)
            chunks_count = await lightrag_manager._count_chunks(rag_instance)
            relationships_count = await lightrag_manager._count_relationships(rag_instance)
            logger.info(
                f"📊 Final counts: {chunks_count} chunks, {entities_count} entities, {relationships_count} relationships")
            if entities_count > 0:
                logger.info("🎉 Entity extraction SUCCESS - Custom AI should work!")
            else:
                logger.warning("⚠️ No entities extracted - may affect performance")
        # Persist the RAG instance so it survives restarts (same as fire-safety).
        logger.info(f"💾 Saving RAG instance to database...")
        config = RAGConfig(  # RAGConfig is already imported at module top
            ai_type="custom",
            user_id=current_user["id"],
            ai_id=ai_id,
            name=ai_data.name,
            description=ai_data.description
        )
        await lightrag_manager._save_to_database(config, rag_instance)
        logger.info(f"✅ RAG instance saved to database successfully")
        # Clear the staging area so these files are not reused for another AI.
        upload_files.temp_storage[current_user["id"]] = []
        # Return AI info
        ai_info = {
            "id": ai_id,
            "name": ai_data.name,
            "description": ai_data.description,
            "created_at": datetime.now().isoformat(),
            "files_count": len(uploaded_files),
            "chunks_processed": successful_insertions,
            "persistent_storage": True,
            "lightrag_python": True
        }
        await lightrag_manager.db.update_system_stat('ais', 1)
        logger.info(f"📈 Custom AI created and stats updated: {ai_data.name}")
        return {
            "ai_id": ai_id,
            "message": "Custom AI created successfully",
            "ai_info": ai_info
        }
    except Exception as e:
        logger.error(f"❌ Error creating custom AI: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to create custom AI: {str(e)}")
async def update_message_stats(user_message: str, ai_response: str):
    """Record message statistics for one user/AI exchange.

    Counts the characters in both messages (for the debug log) and bumps the
    global message counter by 2 (one user message + one AI reply). Best-effort:
    any failure is logged as a warning and swallowed so chat flow is never
    interrupted by stats bookkeeping.
    """
    try:
        if not lightrag_manager:
            return
        # Character totals are currently log-only; the persisted stat is the
        # message count. NOTE(review): the docstring mentions character count
        # but no 'characters' stat is written — confirm whether intentional.
        total_chars = sum(len(msg) for msg in (user_message, ai_response) if msg)
        # One exchange = two messages (user + AI).
        await lightrag_manager.db.update_system_stat('messages', 2)
        logger.debug(f"📊 Message stats updated: {total_chars} characters, 2 messages")
    except Exception as e:
        logger.warning(f"Failed to update message stats: {e}")
@app.get("/custom-ai/{ai_id}")
async def get_custom_ai_details(
    ai_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Get custom AI details, restricted to the AI's creator.

    Raises:
        HTTPException 503: backend not initialized.
        HTTPException 404: no active custom AI with this id.
        HTTPException 403: requester does not own this AI.
        HTTPException 500: unexpected lookup failure.
    """
    try:
        if not lightrag_manager:
            raise HTTPException(status_code=503, detail="LightRAG system not initialized")

        async with lightrag_manager.db.pool.acquire() as conn:
            # Fetch the AI row; ai_id is exposed as `id` in the response.
            row = await conn.fetchrow("""
                SELECT 
                    ri.ai_id as id,
                    ri.user_id,
                    ri.name,
                    ri.description,
                    ri.created_at,
                    ri.file_count as files_count,
                    ri.total_tokens,
                    ri.file_count,
                    ri.total_chunks
                FROM rag_instances ri
                WHERE ri.ai_id = $1 AND ri.ai_type = 'custom' AND ri.status = 'active'
            """, ai_id)

            if not row:
                raise HTTPException(status_code=404, detail="Custom AI not found")

            # Ownership check: only the creator may read this AI's details.
            if row['user_id'] != current_user["id"]:
                raise HTTPException(status_code=403, detail="Access denied. You can only access AI assistants you created.")

            # Convert the record to a JSON-friendly dict.
            ai_data = dict(row)
            ai_data['created_at'] = ai_data['created_at'].isoformat()
            ai_data['files_count'] = ai_data['file_count'] or 0

            logger.info(f"✅ Custom AI details retrieved for user {current_user['id']}: {ai_data['name']}")
            return ai_data
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting custom AI details: {e}")
        raise HTTPException(status_code=500, detail="Failed to get AI details")
@app.get("/my-ais")
async def get_my_ais(current_user: dict = Depends(get_current_user)):
    """List the current user's active custom AIs.

    Joins ``rag_instances`` with ``knowledge_files`` so the file count reflects
    what is actually stored. Degrades to an empty list (never an HTTP error)
    when the backend is unavailable or the query fails.
    """
    try:
        if not lightrag_manager:
            return {"ais": [], "count": 0}

        async with lightrag_manager.db.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT 
                    ri.ai_id as id,
                    ri.name,
                    ri.description,
                    ri.created_at,
                    ri.total_chunks,
                    ri.total_tokens,
                    ri.file_count,
                    COUNT(kf.id) as actual_files
                FROM rag_instances ri
                LEFT JOIN knowledge_files kf ON kf.rag_instance_id = ri.id
                WHERE ri.user_id = $1 AND ri.ai_type = 'custom' AND ri.status = 'active'
                GROUP BY ri.ai_id, ri.name, ri.description, ri.created_at, ri.total_chunks, ri.total_tokens, ri.file_count
                ORDER BY ri.created_at DESC
            """, current_user["id"])

        formatted_ais = [
            {
                "id": row['id'],
                "name": row['name'],
                "description": row['description'],
                "created_at": row['created_at'].isoformat(),
                # Prefer the live count from the join, then the cached column.
                "files_count": row['actual_files'] or row['file_count'] or 0,
                "persistent_storage": True,
                "lightrag_python": True,
                "total_tokens": row['total_tokens'] or 0,
                # ~4 chars per token: rough knowledge-size estimate.
                "knowledge_size": row['total_tokens'] * 4 if row['total_tokens'] else 0
            }
            for row in rows
        ]
        return {"ais": formatted_ais, "count": len(formatted_ais)}
    except Exception as e:
        logger.error(f"❌ Error getting user AIs: {e}")
        return {"ais": [], "count": 0}
# Get user's conversations
@app.get("/conversations")
async def get_conversations(current_user: dict = Depends(get_current_user)):
    """Return the user's 50 most recently updated active conversations."""
    if not lightrag_manager:
        return {"conversations": []}
    try:
        async with lightrag_manager.db.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, title, ai_type, ai_id, created_at, updated_at
                FROM conversations
                WHERE user_id = $1 AND is_active = true
                ORDER BY updated_at DESC
                LIMIT 50
            """, current_user["id"])

        return {
            "conversations": [
                {
                    "id": str(row['id']),
                    "title": row['title'] or "New Conversation",
                    "ai_type": row['ai_type'],
                    "ai_id": row['ai_id'],
                    # Message preview is not fetched here; static placeholder.
                    "last_message": "Click to view conversation",
                    "updated_at": row['updated_at'].isoformat()
                }
                for row in rows
            ]
        }
    except Exception as e:
        logger.error(f"Error getting conversations: {e}")
        return {"conversations": []}
# Get conversation messages
@app.get("/conversations/{conversation_id}")
async def get_conversation_messages(
    conversation_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Get messages for a specific conversation.

    Returns ``{"messages": [...]}`` where each message carries id, role,
    content and created_at. Failures (including an uninitialized backend)
    degrade to an empty message list rather than an HTTP error.
    """
    try:
        # Consistency fix: sibling endpoints guard against an uninitialized
        # manager explicitly instead of relying on an AttributeError below.
        if not lightrag_manager:
            return {"messages": []}
        # NOTE(review): conversation ownership is not verified against
        # current_user here — confirm access control is enforced inside
        # db.get_conversation_messages.
        messages = await lightrag_manager.db.get_conversation_messages(conversation_id)
        formatted_messages = [
            {
                # Fall back to fresh values when the store omits id/timestamp.
                "id": str(msg.get("id", uuid.uuid4())),
                "role": msg["role"],
                "content": msg["content"],
                "created_at": msg.get("created_at", datetime.now().isoformat())
            }
            for msg in messages
        ]
        return {"messages": formatted_messages}
    except Exception as e:
        logger.error(f"Error getting conversation messages: {e}")
        return {"messages": []}
# System information
@app.get("/system/info")
async def get_system_info():
    """Static service metadata: version, feature flags, models, storage map."""
    features = {
        "persistent_lightrag": True,
        "vercel_only_storage": True,
        "jwt_authentication": True,
        "conversation_memory": True,
        "custom_ai_support": True,
        "zero_token_waste": True,
        "graph_rag": True
    }
    models = {
        "llm": "@cf/meta/llama-3.2-3b-instruct",
        "embedding": "@cf/baai/bge-m3"
    }
    storage = {
        "database": "PostgreSQL (Vercel)",
        "cache": "Redis (Upstash)",
        "files": "Vercel Blob",
        "rag_persistence": "Vercel Blob + PostgreSQL",
        "working_directory": "Memory only (no disk persistence)"
    }
    return {
        "service": "YourAI",
        "version": "4.0.0",
        "features": features,
        "models": models,
        "storage": storage
    }
# System status
@app.get("/system/status")
async def get_system_status():
    """Report per-component health plus in-memory cache sizes."""
    components = {
        "lightrag_manager": lightrag_manager is not None,
        "database_pool": lightrag_manager.db.pool is not None if lightrag_manager else False,
        "redis_connection": lightrag_manager.db.redis is not None if lightrag_manager else False,
        "vercel_blob": bool(os.getenv("BLOB_READ_WRITE_TOKEN")),
        "cloudflare_ai": bool(os.getenv("CLOUDFLARE_API_KEY"))
    }
    status = {
        # Healthy only when every component check passes.
        "status": "healthy" if all(components.values()) else "unhealthy",
        "components": components
    }
    if lightrag_manager:
        status["memory"] = {
            "cached_rag_instances": len(lightrag_manager.rag_instances),
            "processing_locks": len(lightrag_manager.processing_locks),
            "conversation_memory": len(lightrag_manager.conversation_memory)
        }
    return status
# Test RAG persistence
@app.get("/test/rag-persistence")
async def test_rag_persistence():
    """Smoke-test that the fire-safety RAG loads and answers a query."""
    if not lightrag_manager:
        return {"error": "LightRAG manager not initialized"}
    try:
        fire_safety_rag = await lightrag_manager.get_or_create_rag_instance("fire-safety")

        # Run one representative hybrid query against the loaded instance.
        from lightrag import QueryParam
        test_response = await fire_safety_rag.aquery(
            "What are the fire exit requirements?",
            QueryParam(mode="hybrid")
        )

        # Truncate long answers so the debug payload stays small.
        preview = test_response if len(test_response) <= 200 else test_response[:200] + "..."
        return {
            "message": "RAG persistence test successful",
            "fire_safety_rag_loaded": fire_safety_rag is not None,
            "test_query_response": preview,
            "cached_instances": len(lightrag_manager.rag_instances),
            "vercel_only_storage": True,
            "zero_token_waste": True
        }
    except Exception as e:
        logger.error(f"RAG persistence test failed: {e}")
        return {"error": str(e), "vercel_only_storage": False}
# Legacy endpoints for compatibility
@app.post("/ask", response_model=QuestionResponse)
async def ask_legacy(
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Deprecated alias for the fire-safety chat endpoint.

    Kept so older clients that still POST to /ask keep working; it delegates
    unchanged to chat_fire_safety with the same request and user.
    """
    return await chat_fire_safety(request, current_user)
@app.get("/modes")
async def get_modes():
    """List the supported LightRAG query modes and their descriptions."""
    descriptions = {
        "hybrid": "Combines local and global knowledge",
        "local": "Uses local knowledge base",
        "global": "Uses global knowledge",
        "naive": "Simple retrieval"
    }
    # Mode order follows the (insertion-ordered) descriptions mapping.
    return {
        "modes": list(descriptions),
        "default": "hybrid",
        "descriptions": descriptions
    }
@app.get("/examples")
async def get_examples():
    """Sample questions for the fire-safety assistant and general chat."""
    fire_safety_examples = [
        "What are the fire exit requirements for a commercial building?",
        "How many fire extinguishers are needed in an office space?",
        "What is the maximum travel distance to an exit?",
        "What are the requirements for emergency lighting?",
        "How often should fire safety equipment be inspected?"
    ]
    general_examples = [
        "How do I create a presentation?",
        "What is machine learning?",
        "Explain quantum computing",
        "Help me plan a project timeline",
        "What are the best practices for remote work?"
    ]
    return {"fire_safety": fire_safety_examples, "general": general_examples}
async def query_rag_with_fallback(
    lightrag_manager: PersistentLightRAGManager,
    ai_type: str,
    question: str,
    conversation_id: str,
    user_id: str,
    ai_id: Optional[str] = None,
    preferred_mode: str = "hybrid"
) -> Dict[str, Any]:
    """Query a RAG instance, falling back through modes until one succeeds.

    Tries the caller's preferred mode first, then the remaining modes in the
    order hybrid -> local -> global -> naive. A mode counts as successful when
    it returns a non-trivial answer (more than 20 stripped chars and not
    starting with "Sorry").

    Args:
        lightrag_manager: Manager used to run memory-aware queries.
        ai_type: Which knowledge base to query (e.g. "fire-safety", "custom").
        question: The user's question.
        conversation_id: Conversation whose memory should be used.
        user_id: Id of the requesting user.
        ai_id: Specific custom AI id when ai_type is "custom".
        preferred_mode: Mode to attempt first (default "hybrid").

    Returns:
        Dict with "answer", the "mode" that succeeded, "status", and
        "fallback_used" (True when the preferred mode was not the one that
        produced the answer).

    Raises:
        HTTPException 500: every mode failed or returned an unusable answer.
    """
    # Preferred mode first, then the remaining modes in default priority order.
    fallback_modes = ["hybrid", "local", "global", "naive"]
    if preferred_mode in fallback_modes:
        fallback_modes.remove(preferred_mode)
        fallback_modes.insert(0, preferred_mode)

    last_error = None
    for mode in fallback_modes:
        try:
            logger.info(f"🔍 Trying {mode} mode for {ai_type} query")
            response = await lightrag_manager.query_with_memory(
                ai_type=ai_type,
                question=question,
                conversation_id=conversation_id,
                user_id=user_id,
                ai_id=ai_id,
                mode=mode
            )
            # Treat short or apologetic responses as failures so the next
            # mode gets a chance.
            if response and len(response.strip()) > 20 and not response.startswith("Sorry"):
                logger.info(f"✅ {mode} mode worked for {ai_type}")
                return {
                    "answer": response,
                    "mode": mode,
                    "status": "success",
                    "fallback_used": mode != preferred_mode
                }
            else:
                logger.warning(f"⚠️ {mode} mode returned empty/error response")
                # BUG FIX: response may be None here; slicing None raised a
                # TypeError that was swallowed below and clobbered the real
                # diagnostic. str() makes the preview safe for any value.
                last_error = f"{mode} mode returned: {str(response)[:100]}..."
        except Exception as e:
            logger.warning(f"⚠️ {mode} mode failed: {e}")
            last_error = str(e)
            continue

    # Every mode failed — surface the most recent error to the caller.
    raise HTTPException(
        status_code=500,
        detail=f"All query modes failed. Last error: {last_error}"
    )
@app.get("/debug/rag-status")
async def debug_rag_status():
    """Debug endpoint: cached RAG instances plus per-instance DB/storage info."""
    if not lightrag_manager:
        return {"error": "LightRAG manager not initialized"}
    try:
        status = {"cached_instances": list(lightrag_manager.rag_instances.keys())}

        async with lightrag_manager.db.pool.acquire() as conn:
            db_instances = await conn.fetch("""
                SELECT id, ai_type, user_id, ai_id, name, total_chunks, total_tokens, status, created_at,
                       graph_blob_url, vector_blob_url, config_blob_url
                FROM rag_instances
                WHERE status = 'active'
                ORDER BY created_at DESC
            """)

            instances_report = []
            for instance in db_instances:
                # Look up the serialized LightRAG storage row per instance.
                storage_info = await conn.fetchrow("""
                    SELECT filename, file_size, token_count, processing_status
                    FROM knowledge_files
                    WHERE rag_instance_id = $1 AND filename = 'lightrag_storage.json'
                    LIMIT 1
                """, instance['id'])

                graph_url = instance['graph_blob_url']
                # URLs prefixed with database:// mean the payload lives in
                # Postgres rather than Vercel Blob.
                blob_type = "vercel" if (graph_url and not graph_url.startswith('database://')) else "database"

                instances_report.append({
                    "id": str(instance['id']),
                    "ai_type": instance['ai_type'],
                    "user_id": instance['user_id'],
                    "ai_id": instance['ai_id'],
                    "name": instance['name'],
                    "total_chunks": instance['total_chunks'],
                    "total_tokens": instance['total_tokens'],
                    "status": instance['status'],
                    "created_at": instance['created_at'].isoformat(),
                    "has_blob_urls": bool(graph_url and instance['vector_blob_url']),
                    "blob_type": blob_type,
                    "has_storage": storage_info is not None,
                    "storage_size": storage_info['file_size'] if storage_info else 0,
                    "storage_tokens": storage_info['token_count'] if storage_info else 0,
                    "storage_status": storage_info['processing_status'] if storage_info else "none"
                })

        status["database_instances"] = instances_report
        return status
    except Exception as e:
        return {"error": str(e)}
if __name__ == "__main__":
    # Standalone entry point: serve the FastAPI app with uvicorn.
    # Port 7860 is the Hugging Face Spaces convention; bind all interfaces.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)