Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from typing import List, Optional, Dict, Any | |
| import uvicorn | |
| import logging | |
| import time | |
| import os | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
# Configure logging: INFO level via basicConfig, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global RAG system state. Written by lifespan(), read by the request handlers.
# rag_system: dict of initialized components after a successful startup, else None.
rag_system: Optional[Dict[str, Any]] = None
# system_loading: True only while lifespan() is initializing the components.
system_loading: bool = False
# system_load_error: str() of the exception from a failed initialization, else None.
system_load_error: Optional[str] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the RAG components on startup and log on shutdown.

    On startup this builds the vector store, retriever, prompt engine, SQL
    generator and data processor, publishes them through the module-level
    ``rag_system`` dict, and seeds the vector store with examples.

    Initialization failures are not re-raised: the error text is stored in
    ``system_load_error`` and ``rag_system`` stays ``None``, so the app still
    starts and the handlers can report the problem.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control returns to the framework while the app is serving.
    """
    global rag_system, system_loading, system_load_error

    # Startup
    logger.info("Starting Text-to-SQL RAG API with CodeLlama for HF Spaces...")

    # Initialization runs inline, before the first ``yield``; nothing here is
    # scheduled as a background task.
    system_loading = True
    system_load_error = None
    try:
        # Imported here rather than at module level so that importing this
        # module does not pull in the RAG package.
        from rag_system import VectorStore, SQLRetriever, PromptEngine, SQLGenerator, DataProcessor

        logger.info("Initializing RAG system components...")

        logger.info("Initializing vector store...")
        vector_store = VectorStore()

        logger.info("Initializing SQL retriever...")
        sql_retriever = SQLRetriever(vector_store)

        logger.info("Initializing prompt engine...")
        prompt_engine = PromptEngine()

        logger.info("Initializing SQL generator with CodeLlama...")
        sql_generator = SQLGenerator(sql_retriever, prompt_engine)

        logger.info("Initializing data processor...")
        data_processor = DataProcessor()

        # Publish the components for the request handlers.
        rag_system = {
            "vector_store": vector_store,
            "sql_retriever": sql_retriever,
            "prompt_engine": prompt_engine,
            "sql_generator": sql_generator,
            "data_processor": data_processor
        }

        logger.info("Loading sample data...")
        await load_or_create_sample_data(data_processor, vector_store)
        logger.info("All RAG system components initialized successfully!")
    except Exception as e:
        # Top-level startup boundary: keep the traceback in the log and expose
        # the message through the health endpoint instead of crashing the app.
        logger.exception("Failed to initialize RAG system: %s", e)
        system_load_error = str(e)
    finally:
        system_loading = False

    yield

    # Shutdown
    logger.info("Shutting down Text-to-SQL RAG API...")
async def load_or_create_sample_data(data_processor, vector_store):
    """Fill the vector store from processed data, or from a sample dataset.

    Previously processed examples are used when ``load_processed_data()``
    returns a non-empty result; otherwise a sample dataset is generated. If
    anything in that first attempt raises, sample generation is tried once
    more. Errors are logged and never propagated.
    """

    def _index_generated_samples():
        # Generate the sample dataset, add it to the vector store, log the count.
        generated = data_processor.create_sample_dataset()
        vector_store.add_examples(generated)
        logger.info(f"Added {len(generated)} sample examples to vector store")

    try:
        stored = data_processor.load_processed_data()
        if not stored:
            logger.info("Creating sample dataset...")
            _index_generated_samples()
        else:
            logger.info(f"Loaded {len(stored)} existing examples")
            vector_store.add_examples(stored)
    except Exception as first_error:
        logger.warning(f"Could not load sample data: {first_error}")
        # Second chance: go straight to the generated sample dataset.
        try:
            _index_generated_samples()
        except Exception as second_error:
            logger.error(f"Failed to create sample data: {second_error}")
# Create FastAPI app; startup and shutdown work is delegated to lifespan().
app = FastAPI(
    title="Text-to-SQL RAG API with CodeLlama",
    description="Advanced API for converting natural language questions to SQL queries using RAG and CodeLlama",
    version="2.0.0",
    lifespan=lifespan
)
| # Pydantic models for request/response | |
class SQLRequest(BaseModel):
    """Input for one text-to-SQL conversion (also the item type of BatchRequest.queries)."""
    # Natural language question, forwarded to generate_sql(question=...).
    question: str
    # Forwarded to generate_sql(table_headers=...); presumably the column names
    # of the target table - confirm against the generator.
    table_headers: List[str]
class SQLResponse(BaseModel):
    """Result of one text-to-SQL conversion."""
    # NOTE(review): if this runs on Pydantic v2, field names starting with
    # "model_" may trigger a protected-namespace warning - confirm.
    # Echo of the request's question.
    question: str
    # Echo of the request's table_headers.
    table_headers: List[str]
    # Generated SQL; batch_predict uses "" for a query that raised.
    sql_query: str
    # Taken from the generator result; batch_predict uses "none" for a query that raised.
    model_used: str
    # Elapsed time for the query; predict_sql fills it from a time.time() difference (seconds).
    processing_time: float
    # Examples returned by the generator under "retrieved_examples".
    retrieved_examples: List[Dict[str, Any]]
    # Generator-reported status; batch_predict counts "success" and emits "error" on failure.
    status: str
class BatchRequest(BaseModel):
    """Input for batch conversion: a list of independent single-query requests."""
    # Queries are processed one by one, in list order, by batch_predict.
    queries: List[SQLRequest]
class BatchResponse(BaseModel):
    """Output of batch conversion."""
    # One entry per input query, including entries with status "error".
    results: List[SQLResponse]
    # Number of queries in the request.
    total_queries: int
    # Number of results whose status is "success".
    successful_queries: int
class HealthResponse(BaseModel):
    """Payload returned by the health check."""
    # NOTE(review): if this runs on Pydantic v2, field names starting with
    # "model_" may trigger a protected-namespace warning - confirm.
    # "healthy" when the RAG system is loaded and not loading, otherwise "unhealthy".
    status: str
    # True once the global rag_system has been populated.
    system_loaded: bool
    # True while startup initialization is still running.
    system_loading: bool
    # Message of the exception that aborted initialization, if any.
    system_error: Optional[str] = None
    # Whatever sql_generator.get_model_info() returns, when it could be obtained.
    model_info: Optional[Dict[str, Any]] = None
    # time.time() at the moment the response was built.
    timestamp: float
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main HTML interface.

    ``index.html`` is looked up next to this module first, which is the
    location the fallback page tells the user to use. The current working
    directory is tried second, so existing deployments keep working.

    Returns:
        HTMLResponse with the contents of ``index.html``, or a small
        placeholder page when the file is not found in either location.
    """
    candidates = (
        Path(__file__).resolve().parent / "index.html",
        Path("index.html"),
    )
    for candidate in candidates:
        try:
            return HTMLResponse(content=candidate.read_text(encoding="utf-8"))
        except FileNotFoundError:
            continue

    return HTMLResponse(content="""
<html>
<body>
<h1>Text-to-SQL RAG API with CodeLlama</h1>
<p>Advanced SQL generation using RAG and CodeLlama models</p>
<p>index.html not found. Please ensure the file exists in the same directory.</p>
</body>
</html>
""")
@app.get("/api")
async def api_info():
    """Return static API metadata: name, version, feature list and endpoint map.

    Returns:
        A JSON-serializable dict with the keys ``message``, ``version``,
        ``features`` and ``endpoints``.
    """
    return {
        "message": "Text-to-SQL RAG API with CodeLlama",
        "version": "2.0.0",
        "features": [
            "RAG-enhanced SQL generation",
            "CodeLlama as primary model",
            "Vector-based example retrieval",
            "Advanced prompt engineering"
        ],
        "endpoints": {
            "/": "GET - Web interface",
            "/api": "GET - API information",
            "/predict": "POST - Generate SQL from single question",
            "/batch": "POST - Generate SQL from multiple questions",
            "/health": "GET - Health check",
            "/docs": "GET - API documentation"
        }
    }
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the RAG system is loaded, loading, or failed to load.

    Returns:
        HealthResponse whose ``status`` is "healthy" only when the RAG system
        is populated and startup has finished. ``model_info`` is filled from
        the SQL generator when available; a failure to obtain it is logged
        and leaves the field as ``None``.
    """
    # The module-level state is only read here, so no ``global`` is needed.
    model_info = None
    if rag_system and "sql_generator" in rag_system:
        try:
            model_info = rag_system["sql_generator"].get_model_info()
        except Exception as e:
            # Best effort: the health check must still answer.
            logger.warning(f"Could not get model info: {e}")

    return HealthResponse(
        status="healthy" if rag_system and not system_loading else "unhealthy",
        system_loaded=rag_system is not None,
        system_loading=system_loading,
        system_error=system_load_error,
        model_info=model_info,
        timestamp=time.time()
    )
@app.post("/predict", response_model=SQLResponse)
async def predict_sql(request: SQLRequest):
    """Generate a SQL query from one natural language question.

    Args:
        request: SQLRequest containing the question and the table headers.

    Returns:
        SQLResponse with the generated SQL query and metadata. The
        ``processing_time`` is measured here, around the generator call.

    Raises:
        HTTPException: 503 while the system is loading or when it failed to
            load; 500 when SQL generation raises.
    """
    # The module-level state is only read here, so no ``global`` is needed.
    if system_loading:
        raise HTTPException(status_code=503, detail="System is still loading, please try again in a few minutes")
    if rag_system is None:
        error_msg = system_load_error or "RAG system not loaded"
        raise HTTPException(status_code=503, detail=f"System not available: {error_msg}")

    start_time = time.time()
    try:
        # NOTE(review): generate_sql is called synchronously inside an async
        # handler; if it is slow it will block the event loop - confirm and
        # consider offloading it to a worker thread.
        result = rag_system["sql_generator"].generate_sql(
            question=request.question,
            table_headers=request.table_headers
        )
        processing_time = time.time() - start_time
        return SQLResponse(
            question=request.question,
            table_headers=request.table_headers,
            sql_query=result["sql_query"],
            model_used=result["model_used"],
            processing_time=processing_time,
            retrieved_examples=result["retrieved_examples"],
            status=result["status"]
        )
    except Exception as e:
        # Keep the traceback in the log and chain the cause on the HTTP error.
        logger.exception("Error generating SQL: %s", e)
        raise HTTPException(status_code=500, detail=f"Error generating SQL: {str(e)}") from e
@app.post("/batch", response_model=BatchResponse)
async def batch_predict(request: BatchRequest):
    """Generate SQL queries for several questions in one call.

    Queries are processed sequentially. A query that raises does not abort
    the batch: it is logged and reported as an item with status "error".

    Args:
        request: BatchRequest containing the list of questions and table headers.

    Returns:
        BatchResponse with one result per query, the total number of queries,
        and the number of results whose status is "success".

    Raises:
        HTTPException: 503 while the system is loading or when it failed to
            load; 500 when the batch as a whole cannot be assembled.
    """
    # The module-level state is only read here, so no ``global`` is needed.
    if system_loading:
        raise HTTPException(status_code=503, detail="System is still loading, please try again in a few minutes")
    if rag_system is None:
        error_msg = system_load_error or "RAG system not loaded"
        raise HTTPException(status_code=503, detail=f"System not available: {error_msg}")

    start_time = time.time()
    try:
        results = []
        successful_count = 0
        for query in request.queries:
            query_start = time.time()
            try:
                result = rag_system["sql_generator"].generate_sql(
                    question=query.question,
                    table_headers=query.table_headers
                )
                elapsed = time.time() - query_start
                sql_response = SQLResponse(
                    question=query.question,
                    table_headers=query.table_headers,
                    sql_query=result["sql_query"],
                    model_used=result["model_used"],
                    # Nothing visible in this file guarantees that the generator
                    # result carries "processing_time". Use it when present,
                    # otherwise the locally measured time, so a missing key no
                    # longer turns a successful query into an error.
                    processing_time=result.get("processing_time", elapsed),
                    retrieved_examples=result["retrieved_examples"],
                    status=result["status"]
                )
                results.append(sql_response)
                if result["status"] == "success":
                    successful_count += 1
            except Exception as e:
                logger.error(f"Error processing query '{query.question}': {str(e)}")
                # Placeholder entry so results stay aligned with the input queries.
                error_response = SQLResponse(
                    question=query.question,
                    table_headers=query.table_headers,
                    sql_query="",
                    model_used="none",
                    processing_time=0.0,
                    retrieved_examples=[],
                    status="error"
                )
                results.append(error_response)

        total_time = time.time() - start_time
        logger.info(
            "Batch processed %d queries (%d successful) in %.2fs",
            len(request.queries), successful_count, total_time
        )
        return BatchResponse(
            results=results,
            total_queries=len(request.queries),
            successful_queries=successful_count
        )
    except Exception as e:
        # Keep the traceback in the log and chain the cause on the HTTP error.
        logger.exception("Error in batch processing: %s", e)
        raise HTTPException(status_code=500, detail=f"Error in batch processing: {str(e)}") from e
if __name__ == "__main__":
    # Run the app directly on all interfaces. The PORT environment variable
    # overrides the listening port; 8000 remains the default when it is unset.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))