import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import logging
import config
import os
from dotenv import load_dotenv
from langsmith.run_helpers import traceable

# Load environment variables from .env file
load_dotenv()

logger = logging.getLogger("swayam-chatbot")

# Initialize Groq client with proper error handling
try:
    from groq import Groq

    # Try to get API key from environment directly as a fallback
    api_key = config.GROQ_API_KEY or os.environ.get("GROQ_API_KEY")
    if not api_key:
        logger.warning("No Groq API key found. LLM functionality will not work.")
        client = None
    else:
        client = Groq(api_key=api_key)
        logger.info("Groq client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Groq client: {e}")
    client = None


# Function for mean pooling to get sentence embeddings
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Load embeddings and model once at startup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
chunks = None
embeddings = None


def load_resources():
    """Load the embedding model and pre-computed embeddings"""
    global tokenizer, model, chunks, embeddings

    # Load model and tokenizer
    logger.info("Loading embedding model...")
    tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    model.to(device)

    # Create embeddings directory if it doesn't exist
    os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)

    # Load stored chunks and embeddings
    logger.info("Loading pre-computed embeddings...")
    try:
        with open(config.CHUNK_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(config.EMBEDDING_PATH, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
        return True
    except FileNotFoundError as e:
        logger.error(f"Error loading embeddings: {e}")
        # Try downloading from cloud storage if available
        if config.EMBEDDINGS_CLOUD_URL:
            logger.info("Attempting to download embeddings from cloud storage...")
            success = download_embeddings_from_cloud()
            if success:
                return load_resources()  # Try loading again after download
        return False


def download_embeddings_from_cloud():
    """Download embeddings from cloud storage"""
    try:
        import requests

        # Download chunks file
        logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
        response = requests.get(config.CHUNKS_CLOUD_URL)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
            with open(config.CHUNK_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded chunks file")
        else:
            logger.error(f"Failed to download chunks: {response.status_code}")
            return False

        # Download embeddings file
        logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
        response = requests.get(config.EMBEDDINGS_CLOUD_URL)
        if response.status_code == 200:
            with open(config.EMBEDDING_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded embeddings file")
            return True
        else:
            logger.error(f"Failed to download embeddings: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"Error downloading embeddings: {e}")
        return False

def is_personal_query(query):
    """Determine if a query is about Swayam or general knowledge"""
    query_lower = query.lower()

    # Check if query contains personal keywords
    for keyword in config.PERSONAL_KEYWORDS:
        if keyword.lower() in query_lower:
            logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
            return True

    logger.info("Query classified as GENERAL")
    return False


@traceable(run_type="retriever", name="E5 Vector Retriever")
def get_relevant_context(query, top_k=3):
    """Retrieve relevant context from embeddings for a given query"""
    if tokenizer is None or model is None:
        logger.error("Embedding model not loaded. Call load_resources() first.")
        return ""

    # Process query with e5 model - use "query: " prefix for better retrieval
    inputs = tokenizer(f"query: {query}", padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Get query embedding
    query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()

    # Calculate similarity with all chunk embeddings
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Get top k most similar chunks
    top_indices = np.argsort(similarities)[::-1][:top_k]

    # Combine the text from the top chunks
    context_parts = []
    for idx in top_indices:
        _, chunk_text = chunks[idx]
        similarity = similarities[idx]
        if similarity > 0.2:  # Only include reasonably similar chunks
            context_parts.append(chunk_text)
            logger.info(f"Including chunk with similarity: {similarity:.4f}")

    return "\n\n".join(context_parts)


@traceable(run_type="llm", name="Groq LLM")
def get_llm_response(messages):
    """Get response from LLM using Groq API"""
    if client is None:
        logger.error("Groq client not initialized. Cannot get LLM response.")
        return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."

    try:
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=messages,
            temperature=0.7,
            max_completion_tokens=1024,
            top_p=1,
            stream=False,
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error calling Groq LLM API: {e}")
        return "Sorry, I encountered an error while processing your request."


@traceable(run_type="chain", name="Response Generator")
def generate_response(query):
    """Generate a response based on the query type"""
    if is_personal_query(query):
        # Personal query - use RAG approach
        context = get_relevant_context(query)
        logger.info(f"Retrieved context: {context[:200]}...")

        messages = [
            {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
            {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"},
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "personal"}
    else:
        # General query - use LLM directly
        messages = [
            {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "general"}
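

# Minimal local smoke-test sketch (illustrative only, not part of the chatbot's
# request path): it assumes config supplies the paths, model names, keywords, and
# prompts referenced above, and that a Groq API key is available. The sample
# question is hypothetical.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if load_resources():
        result = generate_response("What projects has Swayam worked on?")
        print(f"[{result['type']}] {result['response']}")
    else:
        logger.error("Could not load embeddings; generate or download them before running the chatbot.")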