# safetyAI / app.py
import hashlib
import json
import logging
import os
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import List, Optional, Dict, Any

import bcrypt
import jwt
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File, Form, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, EmailStr

# Import the corrected persistent LightRAG manager
from lightrag_manager import (
    PersistentLightRAGManager,
    CloudflareWorker,
    initialize_lightrag_manager,
    get_lightrag_manager,
    validate_environment,
    EnvironmentError,
    RAGConfig
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables
# NOTE(review): if JWT_SECRET_KEY is unset we fall back to a random
# per-process secret, so every restart invalidates all tokens and multiple
# workers would reject each other's tokens — confirm it is set in production.
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", secrets.token_urlsafe(32))
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24 * 7  # 1 week

# File upload settings
MAX_UPLOAD_SIZE = int(os.getenv("MAX_UPLOAD_SIZE", "10485760"))  # 10MB

# Create FastAPI app
app = FastAPI(
    title="YourAI - Complete LightRAG Backend",
    version="4.0.0",
    description="Complete LightRAG system with Vercel-only persistence and JWT authentication"
)

# Add CORS middleware
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests; tighten origins if cookies
# or credentialed CORS are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
security = HTTPBearer()  # bearer-token extraction for protected routes
# Pydantic models
class UserRegisterRequest(BaseModel):
    """Payload for POST /auth/register.

    Fix: the register endpoint reads ``request.password`` (to bcrypt-hash
    it), but this model did not declare the field, so every registration
    failed with an AttributeError. Declared here to match the handler.
    """
    email: EmailStr
    name: str
    password: str
class UserLoginRequest(BaseModel):
    """Payload for POST /auth/login.

    Fix: the login endpoint verifies ``request.password`` against the stored
    bcrypt hash, but this model did not declare the field, making login
    impossible. Declared here to match the handler.
    """
    email: EmailStr
    password: str
class QuestionRequest(BaseModel):
    """Incoming chat request for any of the /chat endpoints."""
    question: str
    mode: Optional[str] = "hybrid"  # LightRAG query mode: hybrid/local/global/naive
    conversation_id: Optional[str] = None  # omit to start a new conversation

class QuestionResponse(BaseModel):
    """Answer payload returned by the chat endpoints."""
    answer: str
    mode: str
    status: str
    conversation_id: Optional[str] = None

class CustomAIRequest(BaseModel):
    """Metadata for creating a custom AI from previously uploaded files."""
    name: str
    description: str

class FileUploadResponse(BaseModel):
    """Per-file acknowledgement returned by /upload-files."""
    filename: str
    size: int  # bytes actually received
    message: str

class UserResponse(BaseModel):
    """Public user representation embedded in auth responses."""
    id: str
    email: str
    name: str
    created_at: str  # ISO-8601 timestamp

class AuthResponse(BaseModel):
    """Registration/login result: a JWT plus the user it identifies."""
    token: str
    user: UserResponse
    message: str

# Global variables
# Singleton backend manager; populated by the startup event, None until then.
lightrag_manager: Optional[PersistentLightRAGManager] = None
# JWT utilities
def create_jwt_token(user_data: dict) -> str:
    """Create a signed JWT access token for the given user.

    Args:
        user_data: Mapping with at least "id", "email" and "name" keys.

    Returns:
        Encoded JWT string, valid for JWT_EXPIRATION_HOURS.
    """
    # Single timezone-aware snapshot so iat/exp are exactly consistent;
    # datetime.utcnow() is deprecated and produced naive timestamps.
    now = datetime.now(timezone.utc)
    payload = {
        "user_id": user_data["id"],
        "email": user_data["email"],
        "name": user_data["name"],
        "exp": now + timedelta(hours=JWT_EXPIRATION_HOURS),
        "iat": now,
        "type": "access"
    }
    return jwt.encode(payload, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
def verify_jwt_token(token: str) -> dict:
    """Verify a JWT and return its claim payload.

    Raises:
        HTTPException 401: "Token expired" or "Invalid token".
    """
    try:
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
        return payload
    except jwt.ExpiredSignatureError:
        # Expiry reported separately so clients can prompt a re-login.
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
def hash_email(email: str) -> str:
    """Return the hex MD5 digest of *email*, used as a stable user key.

    MD5 is used only as a compact deterministic identifier here, not for
    any security purpose.
    """
    digest = hashlib.md5(email.encode())
    return digest.hexdigest()
# Authentication dependency
async def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> dict:
    """FastAPI dependency: resolve the authenticated user from the bearer token.

    The HTTPBearer scheme extracts the Authorization header; invalid or
    expired tokens raise 401 via verify_jwt_token.
    """
    token = credentials.credentials
    payload = verify_jwt_token(token)
    # Return only the identity claims the handlers need.
    return {
        "id": payload["user_id"],
        "email": payload["email"],
        "name": payload["name"]
    }
# Optional authentication (for endpoints that work with or without auth)
async def get_current_user_optional(authorization: Optional[str] = None) -> Optional[dict]:
    """Return the authenticated user, or None when unauthenticated.

    Unlike get_current_user this never raises: a missing, malformed or
    expired token simply yields anonymous access (None).
    """
    if not authorization:
        return None
    try:
        # Strip only a leading "Bearer " scheme prefix; str.replace() would
        # remove the substring anywhere inside the token value.
        token = authorization.removeprefix("Bearer ")
        payload = verify_jwt_token(token)
        return {
            "id": payload["user_id"],
            "email": payload["email"],
            "name": payload["name"]
        }
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
        # still propagate; the HTTPException from verify_jwt_token lands here.
        return None
# Startup event
@app.on_event("startup")
async def startup_event():
    """Initialize the LightRAG backend when the app starts.

    Order matters: validate env -> build manager -> smoke-test DB ->
    de-duplicate fire-safety RAG rows -> load (or create) the fire-safety
    RAG. Any failure re-raises so the app does not start half-initialized.
    """
    global lightrag_manager
    logger.info("🚀 Starting up YourAI with Complete LightRAG Backend...")
    try:
        # Fail fast if required environment variables are missing.
        validate_environment()
        # Initialize LightRAG manager
        lightrag_manager = await initialize_lightrag_manager()
        # Smoke-test the Postgres pool before doing any real work.
        async with lightrag_manager.db.pool.acquire() as conn:
            result = await conn.fetchval("SELECT 1")
            logger.info(f"✅ Database connection successful: {result}")
        # Remove duplicate fire-safety RAG rows, keeping only the latest.
        logger.info("🧹 Cleaning up duplicate RAG instances...")
        active_instance = await lightrag_manager.db.cleanup_duplicate_rag_instances("fire-safety", keep_latest=True)
        if active_instance:
            logger.info(f"🎯 Using active fire safety RAG: {active_instance['name']} (ID: {active_instance['id']})")
        # Config for the shared (user-independent) fire-safety RAG.
        config = RAGConfig(
            ai_type="fire-safety",
            user_id=None,
            ai_id=None,
            name="Fire Safety Expert",
            description="Expert in fire safety regulations and procedures"
        )
        # Prefer loading the persisted RAG so no tokens are wasted rebuilding.
        logger.info("🔍 Attempting to load existing fire safety RAG from DATABASE...")
        try:
            rag_instance = await lightrag_manager._load_from_database(config)
            if rag_instance:
                # Cache so later requests skip the database round-trip.
                lightrag_manager.rag_instances[config.get_cache_key()] = rag_instance
                logger.info("🎉 SUCCESS: Fire safety RAG loaded from database - NO RECREATION NEEDED")
            else:
                logger.info("🔧 Database loading failed - creating new RAG instance")
                rag_instance = await lightrag_manager.get_or_create_rag_instance(
                    "fire-safety",
                    name="Fire Safety Expert",
                    description="Expert in fire safety regulations and procedures"
                )
                logger.info("✅ Created new fire safety RAG instance")
        except Exception as e:
            logger.error(f"❌ Error during fire safety RAG initialization: {e}")
            # Last resort: rebuild from scratch rather than start broken.
            logger.info("🔧 Creating NEW fire safety RAG (fallback due to error)")
            await lightrag_manager.get_or_create_rag_instance(
                "fire-safety",
                name="Fire Safety Expert",
                description="Expert in fire safety regulations and procedures"
            )
        logger.info("✅ Complete LightRAG system initialized successfully")
        logger.info(f"🎯 Final cached RAG instances: {list(lightrag_manager.rag_instances.keys())}")
    except EnvironmentError as e:
        logger.error(f"❌ Environment validation failed: {e}")
        raise
    except Exception as e:
        logger.error(f"❌ Startup failed: {e}")
        raise
# Shutdown event
@app.on_event("shutdown")
async def shutdown_event():
    """Dispose of the LightRAG manager's resources via its cleanup() hook."""
    global lightrag_manager
    if lightrag_manager:
        await lightrag_manager.cleanup()
    logger.info("Shutdown complete")
# Root endpoint
@app.get("/")
async def root():
    """Service banner: name, version, status and advertised feature flags."""
    feature_flags = {
        "persistent_lightrag": True,
        "vercel_only_storage": True,
        "jwt_authentication": True,
        "conversation_memory": True,
        "custom_ai_support": True,
        "zero_token_waste": True
    }
    return {
        "message": "YourAI - Complete LightRAG Backend API",
        "version": "4.0.0",
        "status": "running",
        "features": feature_flags
    }
# Health check
@app.get("/health")
async def health_check():
    """Lightweight liveness probe reporting backend connection state."""
    manager = lightrag_manager
    db_connected = bool(manager and manager.db.pool is not None)
    redis_connected = bool(manager and manager.db.redis is not None)
    return {
        "status": "healthy",
        "lightrag_initialized": manager is not None,
        "vercel_storage": True,
        "database_connected": db_connected,
        "redis_connected": redis_connected,
        "environment_validated": True
    }
# User registration
@app.post("/auth/register", response_model=AuthResponse)
async def register_user(request: UserRegisterRequest):
    """Register a new user and return a signed JWT.

    Raises:
        HTTPException 400: email already registered.
        HTTPException 503: backend not initialized.
        HTTPException 500: unexpected database failure.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        async with lightrag_manager.db.pool.acquire() as conn:
            # Reject duplicate registrations up front.
            existing_user = await conn.fetchrow("""
                SELECT id FROM users WHERE email = $1
            """, request.email)
            if existing_user:
                raise HTTPException(status_code=400, detail="User already exists")
            # Hash password; bcrypt embeds the per-user salt in the hash.
            password_hash = bcrypt.hashpw(request.password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
            # Create the user row. hashed_email now reuses the shared
            # hash_email() helper instead of duplicating the md5 call inline.
            user_id = await conn.fetchval("""
                INSERT INTO users (id, email, name, password_hash, hashed_email, created_at, updated_at)
                VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
                RETURNING id
            """, str(uuid.uuid4()), request.email, request.name, password_hash,
                hash_email(request.email))
        user = {
            "id": str(user_id),
            "email": request.email,
            "name": request.name,
            # NOTE(review): naive local time, while the DB row uses NOW();
            # confirm nothing relies on these two timestamps matching.
            "created_at": datetime.now().isoformat()
        }
        # Issue the access token immediately so the client is logged in.
        token = create_jwt_token(user)
        logger.info(f"User registered: {request.email}")
        return AuthResponse(
            message="User created successfully",
            user=UserResponse(**user),
            token=token
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Registration failed: {e}")
        raise HTTPException(status_code=500, detail=f"Registration failed: {str(e)}")
@app.post("/auth/login", response_model=AuthResponse)
async def login_user(request: UserLoginRequest):
    """Login user.

    Verifies the bcrypt password hash for an active account and returns a
    fresh JWT. Unknown email and wrong password produce the same 401 so
    account existence is not leaked.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Get user from database; is_active = false accounts cannot log in.
        async with lightrag_manager.db.pool.acquire() as conn:
            user_record = await conn.fetchrow("""
                SELECT id, email, name, password_hash, created_at
                FROM users
                WHERE email = $1 AND is_active = true
            """, request.email)
            if not user_record:
                raise HTTPException(status_code=401, detail="Invalid email or password")
            # Verify password against the stored bcrypt hash.
            # NOTE(review): requires UserLoginRequest to declare a
            # `password` field — confirm the model definition matches.
            if not bcrypt.checkpw(request.password.encode('utf-8'), user_record['password_hash'].encode('utf-8')):
                raise HTTPException(status_code=401, detail="Invalid email or password")
            user = {
                "id": str(user_record['id']),
                "email": user_record['email'],
                "name": user_record['name'],
                "created_at": user_record['created_at'].isoformat()
            }
        # Create JWT token carrying the user's identity claims.
        token = create_jwt_token(user)
        logger.info(f"User logged in: {request.email}")
        return AuthResponse(
            token=token,
            user=UserResponse(**user),
            message="Login successful"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Login failed: {e}")
        raise HTTPException(status_code=500, detail=f"Login failed: {str(e)}")
# Fire Safety Chat
@app.post("/chat/fire-safety", response_model=QuestionResponse)
async def chat_fire_safety(
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Chat with fire safety AI"""
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Continue the caller's conversation, or mint a new id.
        conv_id = request.conversation_id or str(uuid.uuid4())
        query_mode = request.mode or "hybrid"
        answer = await lightrag_manager.query_with_memory(
            ai_type="fire-safety",
            question=request.question,
            conversation_id=conv_id,
            user_id=current_user["id"],
            mode=query_mode
        )
        return QuestionResponse(
            answer=answer,
            mode=query_mode,
            status="success",
            conversation_id=conv_id
        )
    except Exception as e:
        logger.error(f"Fire safety chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
# General Chat
@app.post("/chat/general", response_model=QuestionResponse)
async def chat_general(
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """General AI chat.

    Bypasses LightRAG entirely: the question goes straight to the Cloudflare
    worker LLM, then both turns are persisted so the exchange appears in
    conversation history.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Generate conversation ID if not provided
        conversation_id = request.conversation_id or str(uuid.uuid4())
        # Use Cloudflare worker directly for general chat
        system_prompt = """You are a helpful general AI assistant. Provide accurate, helpful, and engaging responses."""
        response = await lightrag_manager.cloudflare_worker.query(
            request.question,
            system_prompt
        )
        # Persist the user turn first, then the assistant turn, so the
        # stored message order matches the actual exchange.
        await lightrag_manager.db.save_conversation_message(
            conversation_id, "user", request.question, {
                "user_id": current_user["id"],
                "ai_type": "general"
            }
        )
        await lightrag_manager.db.save_conversation_message(
            conversation_id, "assistant", response, {
                "mode": request.mode or "direct",
                "ai_type": "general",
                "user_id": current_user["id"]
            }
        )
        return QuestionResponse(
            answer=response,
            mode=request.mode or "direct",
            status="success",
            conversation_id=conversation_id
        )
    except Exception as e:
        logger.error(f"General chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
# Custom AI Chat
@app.post("/chat/custom/{ai_id}", response_model=QuestionResponse)
async def chat_custom_ai(
    ai_id: str,
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Chat with custom AI"""
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    try:
        # Continue the given conversation, or start a fresh one.
        conv_id = request.conversation_id or str(uuid.uuid4())
        chosen_mode = request.mode or "hybrid"
        answer = await lightrag_manager.query_with_memory(
            ai_type="custom",
            question=request.question,
            conversation_id=conv_id,
            user_id=current_user["id"],
            ai_id=ai_id,
            mode=chosen_mode
        )
        return QuestionResponse(
            answer=answer,
            mode=chosen_mode,
            status="success",
            conversation_id=conv_id
        )
    except Exception as e:
        logger.error(f"Custom AI chat error: {e}")
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
# File upload endpoint
@app.post("/upload-files", response_model=List[FileUploadResponse])
async def upload_files(
    files: List[UploadFile] = File(...),
    current_user: dict = Depends(get_current_user)
):
    """Upload knowledge files for a custom AI.

    Files are staged in-process on the ``upload_files.temp_storage``
    function attribute, keyed by user id, and consumed later by
    /create-custom-ai. Per-process and lost on restart.

    Raises:
        HTTPException 413: a file exceeds MAX_UPLOAD_SIZE.
        HTTPException 400: disallowed file extension.
    """
    uploaded_files = []
    allowed_extensions = {'.txt', '.md', '.json', '.pdf', '.docx'}
    for file in files:
        # Fix: UploadFile.size can be None (e.g. no Content-Length), which
        # made the original `file.size > MAX_UPLOAD_SIZE` raise TypeError.
        # Use the header value only as an early fast-fail...
        if file.size is not None and file.size > MAX_UPLOAD_SIZE:
            raise HTTPException(
                status_code=413,
                detail=f"File {file.filename} too large. Max size: {MAX_UPLOAD_SIZE} bytes"
            )
        file_extension = Path(file.filename or "").suffix.lower()
        if file_extension not in allowed_extensions:
            raise HTTPException(
                status_code=400,
                detail=f"File type {file_extension} not allowed. Allowed: {allowed_extensions}"
            )
        # Read file content
        content = await file.read()
        # ...and enforce the limit on the bytes actually received.
        if len(content) > MAX_UPLOAD_SIZE:
            raise HTTPException(
                status_code=413,
                detail=f"File {file.filename} too large. Max size: {MAX_UPLOAD_SIZE} bytes"
            )
        # NOTE(review): PDFs are decoded as raw text exactly like every
        # other type (the original if/else branches were identical), so
        # binary PDF content will be mostly garbage — use a real PDF text
        # extractor in production.
        text_content = content.decode('utf-8', errors='ignore')
        # In-memory staging area shared with create_custom_ai.
        if not hasattr(upload_files, 'temp_storage'):
            upload_files.temp_storage = {}
        upload_files.temp_storage.setdefault(current_user["id"], []).append({
            "filename": file.filename,
            "content": text_content,
            "type": file_extension.lstrip('.'),
            "size": len(content)
        })
        uploaded_files.append(FileUploadResponse(
            filename=file.filename,
            size=len(content),
            message="Uploaded successfully"
        ))
    return uploaded_files
# Create custom AI
@app.post("/create-custom-ai")
async def create_custom_ai(
    ai_data: CustomAIRequest,
    current_user: dict = Depends(get_current_user)
):
    """Create custom AI with uploaded files.

    Consumes the files previously staged by /upload-files (held on the
    upload_files function attribute, keyed by user id), builds the custom
    RAG, then clears the staging area for that user.
    """
    if not lightrag_manager:
        raise HTTPException(status_code=503, detail="LightRAG system not initialized")
    # Require at least one staged file for this user.
    if not hasattr(upload_files, 'temp_storage') or current_user["id"] not in upload_files.temp_storage:
        raise HTTPException(status_code=400, detail="No files uploaded. Please upload knowledge files first.")
    uploaded_files = upload_files.temp_storage[current_user["id"]]
    if not uploaded_files:
        raise HTTPException(status_code=400, detail="No files found. Please upload knowledge files first.")
    try:
        # Build and persist the custom RAG from the staged files.
        ai_id = await lightrag_manager.create_custom_ai(
            user_id=current_user["id"],
            ai_name=ai_data.name,
            description=ai_data.description,
            uploaded_files=uploaded_files
        )
        # Clear staging so a later AI doesn't silently reuse these files;
        # the local `uploaded_files` reference still holds the old list.
        upload_files.temp_storage[current_user["id"]] = []
        ai_info = {
            "id": ai_id,
            "name": ai_data.name,
            "description": ai_data.description,
            "created_at": datetime.now().isoformat(),
            "files_count": len(uploaded_files),
            "persistent_storage": True,
            "vercel_only": True
        }
        return {
            "ai_id": ai_id,
            "message": "Custom AI created successfully with Vercel-only persistence",
            "ai_info": ai_info
        }
    except Exception as e:
        logger.error(f"Error creating custom AI: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to create custom AI: {str(e)}")
# Get user's conversations
@app.get("/conversations")
async def get_conversations(current_user: dict = Depends(get_current_user)):
    """Get user's conversations"""
    if not lightrag_manager:
        return {"conversations": []}
    try:
        # Fetch the 50 most recently updated active conversations.
        async with lightrag_manager.db.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT id, title, ai_type, ai_id, created_at, updated_at
                FROM conversations
                WHERE user_id = $1 AND is_active = true
                ORDER BY updated_at DESC
                LIMIT 50
            """, current_user["id"])
        # Shape DB rows into the wire format the frontend expects.
        formatted = [
            {
                "id": str(row['id']),
                "title": row['title'] or "New Conversation",
                "ai_type": row['ai_type'],
                "ai_id": row['ai_id'],
                "last_message": "Click to view conversation",
                "updated_at": row['updated_at'].isoformat()
            }
            for row in rows
        ]
        return {"conversations": formatted}
    except Exception as e:
        logger.error(f"Error getting conversations: {e}")
        return {"conversations": []}
# Get conversation messages
@app.get("/conversations/{conversation_id}")
async def get_conversation_messages(
    conversation_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Get messages for a specific conversation.

    Best-effort like /conversations: any failure yields an empty list
    rather than an HTTP error.
    """
    # Consistency fix: every sibling endpoint guards an uninitialized
    # backend; previously this fell through to an AttributeError that the
    # except below silently turned into an empty result.
    if not lightrag_manager:
        return {"messages": []}
    try:
        messages = await lightrag_manager.db.get_conversation_messages(conversation_id)
        # NOTE(review): results are not filtered by current_user — any
        # authenticated user can read any conversation id. Confirm whether
        # the manager enforces ownership.
        formatted_messages = [
            {
                "id": str(msg.get("id", uuid.uuid4())),
                "role": msg["role"],
                "content": msg["content"],
                "created_at": msg.get("created_at", datetime.now().isoformat())
            }
            for msg in messages
        ]
        return {"messages": formatted_messages}
    except Exception as e:
        logger.error(f"Error getting conversation messages: {e}")
        return {"messages": []}
# System information
@app.get("/system/info")
async def get_system_info():
    """Get system information"""
    # Static capability/configuration advertisement; no backend calls.
    features = {
        "persistent_lightrag": True,
        "vercel_only_storage": True,
        "jwt_authentication": True,
        "conversation_memory": True,
        "custom_ai_support": True,
        "zero_token_waste": True,
        "graph_rag": True
    }
    models = {
        "llm": "@cf/meta/llama-3.2-3b-instruct",
        "embedding": "@cf/baai/bge-m3"
    }
    storage = {
        "database": "PostgreSQL (Vercel)",
        "cache": "Redis (Upstash)",
        "files": "Vercel Blob",
        "rag_persistence": "Vercel Blob + PostgreSQL",
        "working_directory": "Memory only (no disk persistence)"
    }
    return {
        "service": "YourAI",
        "version": "4.0.0",
        "features": features,
        "models": models,
        "storage": storage
    }
# System status
@app.get("/system/status")
async def get_system_status():
    """Get system status"""
    manager = lightrag_manager
    components = {
        "lightrag_manager": manager is not None,
        "database_pool": manager.db.pool is not None if manager else False,
        "redis_connection": manager.db.redis is not None if manager else False,
        "vercel_blob": bool(os.getenv("BLOB_READ_WRITE_TOKEN")),
        "cloudflare_ai": bool(os.getenv("CLOUDFLARE_API_KEY"))
    }
    # Healthy only when every component check passed.
    status = {
        "status": "healthy" if all(components.values()) else "unhealthy",
        "components": components
    }
    if manager:
        status["memory"] = {
            "cached_rag_instances": len(manager.rag_instances),
            "processing_locks": len(manager.processing_locks),
            "conversation_memory": len(manager.conversation_memory)
        }
    return status
# Test RAG persistence
@app.get("/test/rag-persistence")
async def test_rag_persistence():
    """Test RAG persistence functionality.

    Diagnostic endpoint: loads (or creates) the fire-safety RAG and runs a
    canned query against it, returning a truncated sample of the answer.
    Errors are reported in the response body, not as HTTP errors.
    """
    if not lightrag_manager:
        return {"error": "LightRAG manager not initialized"}
    try:
        # Test fire safety RAG
        fire_safety_rag = await lightrag_manager.get_or_create_rag_instance("fire-safety")
        # Function-scope import keeps lightrag off the module import path.
        from lightrag import QueryParam
        test_response = await fire_safety_rag.aquery(
            "What are the fire exit requirements?",
            QueryParam(mode="hybrid")
        )
        return {
            "message": "RAG persistence test successful",
            "fire_safety_rag_loaded": fire_safety_rag is not None,
            # Truncate long answers to keep the diagnostic payload small.
            "test_query_response": test_response[:200] + "..." if len(test_response) > 200 else test_response,
            "cached_instances": len(lightrag_manager.rag_instances),
            "vercel_only_storage": True,
            "zero_token_waste": True
        }
    except Exception as e:
        logger.error(f"RAG persistence test failed: {e}")
        return {"error": str(e), "vercel_only_storage": False}
# Legacy endpoints for compatibility
@app.post("/ask", response_model=QuestionResponse)
async def ask_legacy(
    request: QuestionRequest,
    current_user: dict = Depends(get_current_user)
):
    """Legacy endpoint - redirects to fire-safety chat.

    Thin shim kept for old clients; delegates directly so behavior stays
    identical to POST /chat/fire-safety.
    """
    return await chat_fire_safety(request, current_user)
@app.get("/modes")
async def get_modes():
    """Get available query modes"""
    # Single source of truth: the mode list is derived from the
    # description keys (insertion order matches the documented order).
    descriptions = {
        "hybrid": "Combines local and global knowledge",
        "local": "Uses local knowledge base",
        "global": "Uses global knowledge",
        "naive": "Simple retrieval"
    }
    return {
        "modes": list(descriptions),
        "default": "hybrid",
        "descriptions": descriptions
    }
@app.get("/examples")
async def get_examples():
    """Get example queries"""
    # Canned starter prompts surfaced by the frontend.
    fire_safety_examples = [
        "What are the fire exit requirements for a commercial building?",
        "How many fire extinguishers are needed in an office space?",
        "What is the maximum travel distance to an exit?",
        "What are the requirements for emergency lighting?",
        "How often should fire safety equipment be inspected?"
    ]
    general_examples = [
        "How do I create a presentation?",
        "What is machine learning?",
        "Explain quantum computing",
        "Help me plan a project timeline",
        "What are the best practices for remote work?"
    ]
    return {
        "fire_safety": fire_safety_examples,
        "general": general_examples
    }
@app.get("/debug/rag-status")
async def debug_rag_status():
    """Debug endpoint to check RAG instance status.

    Reports both the in-memory cache keys and every active rag_instances
    row in Postgres, with per-row storage details from knowledge_files.
    """
    if not lightrag_manager:
        return {"error": "LightRAG manager not initialized"}
    try:
        status = {}
        # In-process cache of already-loaded RAG objects.
        status["cached_instances"] = list(lightrag_manager.rag_instances.keys())
        # All active instances recorded in the database, newest first.
        async with lightrag_manager.db.pool.acquire() as conn:
            db_instances = await conn.fetch("""
                SELECT id, ai_type, user_id, ai_id, name, total_chunks, total_tokens, status, created_at,
                       graph_blob_url, vector_blob_url, config_blob_url
                FROM rag_instances
                WHERE status = 'active'
                ORDER BY created_at DESC
            """)
            status["database_instances"] = []
            for instance in db_instances:
                # One extra query per instance (N+1) — acceptable for a
                # debug endpoint, not a hot path.
                storage_info = await conn.fetchrow("""
                    SELECT filename, file_size, token_count, processing_status
                    FROM knowledge_files
                    WHERE rag_instance_id = $1 AND filename = 'lightrag_storage.json'
                    LIMIT 1
                """, instance['id'])
                status["database_instances"].append({
                    "id": str(instance['id']),
                    "ai_type": instance['ai_type'],
                    "user_id": instance['user_id'],
                    "ai_id": instance['ai_id'],
                    "name": instance['name'],
                    "total_chunks": instance['total_chunks'],
                    "total_tokens": instance['total_tokens'],
                    "status": instance['status'],
                    "created_at": instance['created_at'].isoformat(),
                    # Both blob URLs present => payload persisted externally.
                    "has_blob_urls": bool(instance['graph_blob_url'] and instance['vector_blob_url']),
                    # A "database://" prefix marks payloads stored in
                    # Postgres instead of Vercel Blob.
                    "blob_type": "vercel" if (instance['graph_blob_url'] and not instance['graph_blob_url'].startswith('database://')) else "database",
                    "has_storage": storage_info is not None,
                    "storage_size": storage_info['file_size'] if storage_info else 0,
                    "storage_tokens": storage_info['token_count'] if storage_info else 0,
                    "storage_status": storage_info['processing_status'] if storage_info else "none"
                })
        return status
    except Exception as e:
        return {"error": str(e)}
if __name__ == "__main__":
    # Local/dev entry point; in deployment an external server imports `app`.
    import uvicorn
    # Port 7860 is the Hugging Face Spaces convention — presumably where
    # this is hosted; confirm before changing.
    uvicorn.run(app, host="0.0.0.0", port=7860)