# portfolio_backend/utils.py
import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import logging
import config
import os
from dotenv import load_dotenv
from langsmith.run_helpers import traceable
# Load environment variables from .env file
load_dotenv()
logger = logging.getLogger("swayam-chatbot")
# Initialize Together client with proper error handling and version compatibility
try:
    # Try different import patterns for different versions of the together library
    try:
        from together import Together
    except ImportError:
        try:
            from together.client import Together
        except ImportError:
            import together
            Together = together.Together

    # Try to get API key from environment directly as a fallback
    api_key = config.TOGETHER_API_KEY or os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        logger.warning("No Together API key found. LLM functionality will not work.")
        client = None
    else:
        client = Together(api_key=api_key)
        logger.info("Together client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Together client: {e}")
    client = None

# Function for mean pooling to get sentence embeddings
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, weighted by the attention mask, into one vector per sentence."""
    token_embeddings = model_output[0]  # (batch, seq_len, hidden_dim)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the non-padding token vectors and divide by the (clamped) count of real tokens
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
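
# Illustrative shape walk-through (a hedged sketch, not part of the original code):
# for a batch of 2 sentences padded to 8 tokens with a 384-dimensional encoder
# (e.g. an e5-small variant), model_output[0] is (2, 8, 384) and attention_mask
# is (2, 8); mean_pooling averages only the non-padded token vectors and returns
# a (2, 384) tensor, i.e. one sentence embedding per input.
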
# Load embeddings and model once at startup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
chunks = None
embeddings = None

def load_resources():
    """Load the embedding model and pre-computed embeddings"""
    global tokenizer, model, chunks, embeddings

    # Load model and tokenizer
    logger.info("Loading embedding model...")
    tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    model.to(device)

    # Create embeddings directory if it doesn't exist
    os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)

    # Load stored chunks and embeddings
    logger.info("Loading pre-computed embeddings...")
    try:
        with open(config.CHUNK_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(config.EMBEDDING_PATH, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
        return True
    except FileNotFoundError as e:
        logger.error(f"Error loading embeddings: {e}")
        # Try downloading from cloud storage if available
        if config.EMBEDDINGS_CLOUD_URL:
            logger.info("Attempting to download embeddings from cloud storage...")
            success = download_embeddings_from_cloud()
            if success:
                return load_resources()  # Try loading again after download
        return False
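
# Note (assumption, inferred from get_relevant_context below): chunks is expected
# to be a sequence of (source_id, chunk_text) pairs and embeddings a NumPy array
# of shape (num_chunks, hidden_dim), with row i holding the embedding of chunks[i].
# Regenerating these pickles would mean encoding each chunk with the same e5-style
# model (conventionally with a "passage: " prefix) and writing both objects to
# config.CHUNK_PATH and config.EMBEDDING_PATH.
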
def download_embeddings_from_cloud():
    """Download embeddings from cloud storage"""
    try:
        import requests

        # Download chunks file
        logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
        response = requests.get(config.CHUNKS_CLOUD_URL)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
            with open(config.CHUNK_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded chunks file")
        else:
            logger.error(f"Failed to download chunks: {response.status_code}")
            return False

        # Download embeddings file
        logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
        response = requests.get(config.EMBEDDINGS_CLOUD_URL)
        if response.status_code == 200:
            with open(config.EMBEDDING_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded embeddings file")
            return True
        else:
            logger.error(f"Failed to download embeddings: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"Error downloading embeddings: {e}")
        return False

def is_personal_query(query):
    """Determine if a query is about Swayam or general knowledge"""
    query_lower = query.lower()

    # Check if query contains personal keywords
    for keyword in config.PERSONAL_KEYWORDS:
        if keyword.lower() in query_lower:
            logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
            return True

    logger.info("Query classified as GENERAL")
    return False
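
# Example (hypothetical, assuming config.PERSONAL_KEYWORDS includes "swayam"
# and "portfolio"):
#   is_personal_query("What projects has Swayam built?")  -> True  (RAG path)
#   is_personal_query("Explain how transformers work")    -> False (direct LLM path)
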
@traceable(run_type="retriever", name="E5 Vector Retriever")
def get_relevant_context(query, top_k=3):
    """Retrieve relevant context from embeddings for a given query"""
    if tokenizer is None or model is None:
        logger.error("Embedding model not loaded. Call load_resources() first.")
        return ""

    # Process query with e5 model - use "query: " prefix for better retrieval
    inputs = tokenizer(f"query: {query}", padding=True, truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Get query embedding
    query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()

    # Calculate similarity with all chunk embeddings
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Get top k most similar chunks
    top_indices = np.argsort(similarities)[::-1][:top_k]

    # Combine the text from the top chunks
    context_parts = []
    for idx in top_indices:
        _, chunk_text = chunks[idx]
        similarity = similarities[idx]
        if similarity > 0.2:  # Only include reasonably similar chunks
            context_parts.append(chunk_text)
            logger.info(f"Including chunk with similarity: {similarity:.4f}")

    return "\n\n".join(context_parts)
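
# Usage sketch (assumes load_resources() has already succeeded):
#   context = get_relevant_context("What projects has Swayam worked on?", top_k=3)
# context is up to top_k chunk texts joined by blank lines; chunks whose cosine
# similarity to the query embedding is 0.2 or below are dropped.
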
@traceable(run_type="llm", name="Together AI LLM")
def get_llm_response(messages):
    """Get response from LLM using Together API"""
    if client is None:
        logger.error("Together client not initialized. Cannot get LLM response.")
        return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."

    try:
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=messages
        )
        return response.choices[0].message.content
    except AttributeError:
        # Handle older version of together library
        try:
            response = client.completions.create(
                model=config.MODEL_NAME,
                prompt=messages[-1]["content"],
                max_tokens=1000
            )
            return response.choices[0].text
        except Exception as e:
            logger.error(f"Error with fallback API call: {e}")
            return "Sorry, I encountered an error while processing your request."
    except Exception as e:
        logger.error(f"Error calling LLM API: {e}")
        return "Sorry, I encountered an error while processing your request."

@traceable(run_type="chain", name="Response Generator")
def generate_response(query):
    """Generate a response based on the query type"""
    if is_personal_query(query):
        # Personal query - use RAG approach
        context = get_relevant_context(query)
        logger.info(f"Retrieved context: {context[:200]}...")
        messages = [
            {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
            {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "personal"}
    else:
        # General query - use LLM directly
        messages = [
            {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
            {"role": "user", "content": query}
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "general"}
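

# Minimal manual smoke test (a hedged sketch, not part of the original module;
# assumes config points at valid chunk/embedding files and that TOGETHER_API_KEY
# is set so the LLM call can succeed).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if load_resources():
        result = generate_response("What projects has Swayam worked on?")
        print(result["type"], "->", result["response"])
    else:
        print("Could not load embeddings; check config.CHUNK_PATH / config.EMBEDDING_PATH.")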