import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import logging
import config
import os
from dotenv import load_dotenv
from langsmith.run_helpers import traceable

# Load environment variables from .env file
load_dotenv()

logger = logging.getLogger("swayam-chatbot")

# Initialize the Groq client with proper error handling
try:
    from groq import Groq

    # Fall back to reading the API key directly from the environment
    api_key = config.GROQ_API_KEY or os.environ.get("GROQ_API_KEY")
    if not api_key:
        logger.warning("No Groq API key found. LLM functionality will not work.")
        client = None
    else:
        client = Groq(api_key=api_key)
        logger.info("Groq client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Groq client: {e}")
    client = None
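
# The imported config module is expected to expose the settings referenced in
# this file. A minimal sketch of config.py is shown below; the example values
# are illustrative assumptions only, not the project's actual configuration:
#
#     GROQ_API_KEY = os.environ.get("GROQ_API_KEY")       # Groq API key
#     MODEL_NAME = "llama-3.1-8b-instant"                  # Groq chat model (example)
#     EMBEDDING_MODEL = "intfloat/e5-base-v2"              # HF embedding model (example)
#     CHUNK_PATH = "embeddings/chunks.pkl"                 # pickled list of (metadata, text) tuples
#     EMBEDDING_PATH = "embeddings/embeddings.pkl"         # pickled numpy embedding matrix
#     CHUNKS_CLOUD_URL = ""                                 # optional fallback download URLs
#     EMBEDDINGS_CLOUD_URL = ""
#     PERSONAL_KEYWORDS = ["swayam", "portfolio", "you"]   # keywords that mark personal queries
#     PERSONAL_SYSTEM_PROMPT = "..."                        # system prompts for the two answer modes
#     GENERAL_SYSTEM_PROMPT = "..."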

# Mean pooling over token embeddings to get a sentence embedding
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# The embeddings and model are loaded once at startup via load_resources()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
chunks = None
embeddings = None

def load_resources():
    """Load the embedding model and pre-computed embeddings."""
    global tokenizer, model, chunks, embeddings

    # Load the model and tokenizer
    logger.info("Loading embedding model...")
    tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    model.to(device)

    # Create the embeddings directory if it doesn't exist
    os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)

    # Load stored chunks and embeddings
    logger.info("Loading pre-computed embeddings...")
    try:
        with open(config.CHUNK_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(config.EMBEDDING_PATH, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
        return True
    except FileNotFoundError as e:
        logger.error(f"Error loading embeddings: {e}")
        # Try downloading from cloud storage if a URL is configured
        if config.EMBEDDINGS_CLOUD_URL:
            logger.info("Attempting to download embeddings from cloud storage...")
            success = download_embeddings_from_cloud()
            if success:
                return load_resources()  # Try loading again after the download
        return False

def download_embeddings_from_cloud():
    """Download the chunks and embeddings files from cloud storage."""
    try:
        import requests

        # Download the chunks file
        logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
        response = requests.get(config.CHUNKS_CLOUD_URL)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
            with open(config.CHUNK_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded chunks file")
        else:
            logger.error(f"Failed to download chunks: {response.status_code}")
            return False

        # Download the embeddings file
        logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
        response = requests.get(config.EMBEDDINGS_CLOUD_URL)
        if response.status_code == 200:
            with open(config.EMBEDDING_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded embeddings file")
            return True
        else:
            logger.error(f"Failed to download embeddings: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"Error downloading embeddings: {e}")
        return False

def is_personal_query(query):
    """Determine whether a query is about Swayam or about general knowledge."""
    query_lower = query.lower()

    # Check whether the query contains any of the personal keywords
    for keyword in config.PERSONAL_KEYWORDS:
        if keyword.lower() in query_lower:
            logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
            return True

    logger.info("Query classified as GENERAL")
    return False

def get_relevant_context(query, top_k=3):
    """Retrieve relevant context from the stored embeddings for a given query."""
    if tokenizer is None or model is None or chunks is None or embeddings is None:
        logger.error("Resources not loaded. Call load_resources() first.")
        return ""

    # Encode the query with the e5 model - the "query: " prefix improves retrieval
    inputs = tokenizer(f"query: {query}", padding=True, truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Get the query embedding
    query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()

    # Calculate similarity against all chunk embeddings
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Get the top-k most similar chunks
    top_indices = np.argsort(similarities)[::-1][:top_k]

    # Combine the text from the top chunks
    context_parts = []
    for idx in top_indices:
        _, chunk_text = chunks[idx]
        similarity = similarities[idx]
        if similarity > 0.2:  # Only include reasonably similar chunks
            context_parts.append(chunk_text)
            logger.info(f"Including chunk with similarity: {similarity:.4f}")

    return "\n\n".join(context_parts)

def get_llm_response(messages):
    """Get a response from the LLM using the Groq API."""
    if client is None:
        logger.error("Groq client not initialized. Cannot get LLM response.")
        return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."

    try:
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=messages,
            temperature=0.7,
            max_completion_tokens=1024,
            top_p=1,
            stream=False
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error calling Groq LLM API: {e}")
        return "Sorry, I encountered an error while processing your request."

def generate_response(query):
    """Generate a response based on the query type."""
    if is_personal_query(query):
        # Personal query - use the RAG approach
        context = get_relevant_context(query)
        logger.info(f"Retrieved context: {context[:200]}...")

        messages = [
            {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
            {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "personal"}
    else:
        # General query - use the LLM directly
        messages = [
            {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
            {"role": "user", "content": query}
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "general"}