import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import logging
import config
import os
from dotenv import load_dotenv
from langsmith.run_helpers import traceable
# Load environment variables from .env file
load_dotenv()
logger = logging.getLogger("swayam-chatbot")
# Initialize Together client with proper error handling and version compatibility
try:
    # Try different import patterns for different versions of the together library
    try:
        from together import Together
    except ImportError:
        try:
            from together.client import Together
        except ImportError:
            import together
            Together = together.Together

    # Try to get the API key from the environment directly as a fallback
    api_key = config.TOGETHER_API_KEY or os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        logger.warning("No Together API key found. LLM functionality will not work.")
        client = None
    else:
        client = Together(api_key=api_key)
        logger.info("Together client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Together client: {e}")
    client = None
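# A .env file in the project root can supply the key that load_dotenv() picks up,
# e.g. (placeholder value, not a real key):
#   TOGETHER_API_KEY=<your-together-api-key>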
# Function for mean pooling to get sentence embeddings
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
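# A minimal illustration of the pooling arithmetic (hypothetical shapes, not part
# of the pipeline): for a batch of one 4-token sequence with hidden size 3 and
# attention_mask = [[1, 1, 1, 0]], the mask zeroes out the padding row, the sum
# is divided by 3 (the count of real tokens), and the result is the mean of the
# three real token vectors, shape (1, 3). The clamp guards against division by
# zero for an all-padding mask.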
# Load embeddings and model once at startup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
chunks = None
embeddings = None
def load_resources():
    """Load the embedding model and pre-computed embeddings"""
    global tokenizer, model, chunks, embeddings

    # Load model and tokenizer
    logger.info("Loading embedding model...")
    tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    model.to(device)

    # Create embeddings directory if it doesn't exist
    os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)

    # Load stored chunks and embeddings
    logger.info("Loading pre-computed embeddings...")
    try:
        with open(config.CHUNK_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(config.EMBEDDING_PATH, "rb") as f:
            embeddings = pickle.load(f)
        logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
        return True
    except FileNotFoundError as e:
        logger.error(f"Error loading embeddings: {e}")
        # Try downloading from cloud storage if available
        if config.EMBEDDINGS_CLOUD_URL:
            logger.info("Attempting to download embeddings from cloud storage...")
            success = download_embeddings_from_cloud()
            if success:
                return load_resources()  # Try loading again after download
        return False
def download_embeddings_from_cloud():
    """Download embeddings from cloud storage"""
    try:
        import requests

        # Download chunks file
        logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
        response = requests.get(config.CHUNKS_CLOUD_URL)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
            with open(config.CHUNK_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded chunks file")
        else:
            logger.error(f"Failed to download chunks: {response.status_code}")
            return False

        # Download embeddings file
        logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
        response = requests.get(config.EMBEDDINGS_CLOUD_URL)
        if response.status_code == 200:
            with open(config.EMBEDDING_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded embeddings file")
            return True
        else:
            logger.error(f"Failed to download embeddings: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"Error downloading embeddings: {e}")
        return False
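# Design note: requests.get() as used above loads each file fully into memory
# before writing it to disk. For very large embedding files, passing stream=True
# and writing via response.iter_content() in chunks would be the more
# memory-friendly choice.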
def is_personal_query(query):
    """Determine if a query is about Swayam or general knowledge"""
    query_lower = query.lower()
    # Check if query contains personal keywords
    for keyword in config.PERSONAL_KEYWORDS:
        if keyword.lower() in query_lower:
            logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
            return True
    logger.info("Query classified as GENERAL")
    return False
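# For example, assuming config.PERSONAL_KEYWORDS contains entries like
# ["swayam", "your projects"] (hypothetical values), the query
# "Tell me about Swayam's projects" matches the substring "swayam" and is routed
# through retrieval, while "What is gradient descent?" matches nothing and goes
# straight to the LLM.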
@traceable(run_type="retriever", name="E5 Vector Retriever")
def get_relevant_context(query, top_k=3):
    """Retrieve relevant context from embeddings for a given query"""
    if tokenizer is None or model is None or chunks is None or embeddings is None:
        logger.error("Embedding model or embeddings not loaded. Call load_resources() first.")
        return ""

    # Process query with e5 model - use "query: " prefix for better retrieval
    inputs = tokenizer(f"query: {query}", padding=True, truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Get query embedding
    query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()

    # Calculate similarity with all chunk embeddings
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Get top-k most similar chunks
    top_indices = np.argsort(similarities)[::-1][:top_k]

    # Combine the text from the top chunks
    context_parts = []
    for idx in top_indices:
        _, chunk_text = chunks[idx]
        similarity = similarities[idx]
        if similarity > 0.2:  # Only include reasonably similar chunks
            context_parts.append(chunk_text)
            logger.info(f"Including chunk with similarity: {similarity:.4f}")
    return "\n\n".join(context_parts)
@traceable(run_type="llm", name="Together AI LLM")
def get_llm_response(messages):
    """Get response from LLM using Together API"""
    if client is None:
        logger.error("Together client not initialized. Cannot get LLM response.")
        return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."
    try:
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=messages
        )
        return response.choices[0].message.content
    except AttributeError:
        # Fall back to the legacy completions API for older versions of the together library
        try:
            response = client.completions.create(
                model=config.MODEL_NAME,
                prompt=messages[-1]["content"],
                max_tokens=1000
            )
            return response.choices[0].text
        except Exception as e:
            logger.error(f"Error with fallback API call: {e}")
            return "Sorry, I encountered an error while processing your request."
    except Exception as e:
        logger.error(f"Error calling LLM API: {e}")
        return "Sorry, I encountered an error while processing your request."
@traceable(run_type="chain", name="Response Generator")
def generate_response(query):
    """Generate a response based on the query type"""
    if is_personal_query(query):
        # Personal query - use RAG approach
        context = get_relevant_context(query)
        logger.info(f"Retrieved context: {context[:200]}...")
        messages = [
            {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
            {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "personal"}
    else:
        # General query - use LLM directly
        messages = [
            {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
            {"role": "user", "content": query}
        ]
        response = get_llm_response(messages)
        return {"response": response, "type": "general"}
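# A minimal sketch of how the module is intended to be driven (hypothetical
# queries; assumes config supplies valid paths, model names, and an API key).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if not load_resources():
        logger.error("Could not load embeddings; personal queries will have no context.")
    # One personal query (routed through retrieval) and one general query.
    for q in ["What projects has Swayam worked on?", "What is cosine similarity?"]:
        result = generate_response(q)
        print(f"[{result['type']}] {result['response'][:200]}")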