import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import logging
import config
import os
from dotenv import load_dotenv
from langsmith.run_helpers import traceable

# Load environment variables from .env file
load_dotenv()

logger = logging.getLogger("swayam-chatbot")

# Initialize Together client with proper error handling and version compatibility
try:
    # Try different import patterns for different versions of together library
    try:
        from together import Together
    except ImportError:
        try:
            from together.client import Together
        except ImportError:
            import together
            Together = together.Together
    
    # Try to get API key from environment directly as a fallback
    api_key = config.TOGETHER_API_KEY or os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        logger.warning("No Together API key found. LLM functionality will not work.")
        client = None
    else:
        client = Together(api_key=api_key)
        logger.info("Together client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Together client: {e}")
    client = None

# Function for mean pooling to get sentence embeddings
def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings into a single sentence embedding, ignoring padding."""
    token_embeddings = model_output[0]  # first element holds all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum the embeddings of real tokens and divide by the (clamped) token count
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Globals for the model, chunks, and embeddings, populated once at startup by load_resources()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
chunks = None
embeddings = None

def load_resources():
    """Load the embedding model and pre-computed embeddings"""
    global tokenizer, model, chunks, embeddings
    
    # Load model and tokenizer
    logger.info("Loading embedding model...")
    tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    model.to(device)
    
    # Create embeddings directory if it doesn't exist
    os.makedirs(os.path.dirname(config.CHUNK_PATH) or ".", exist_ok=True)
    
    # Load stored chunks and embeddings
    logger.info("Loading pre-computed embeddings...")
    try:
        with open(config.CHUNK_PATH, "rb") as f:
            chunks = pickle.load(f)
        
        with open(config.EMBEDDING_PATH, "rb") as f:
            embeddings = pickle.load(f)
            
        logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
        return True
    except FileNotFoundError as e:
        logger.error(f"Error loading embeddings: {e}")
        # Try downloading from cloud storage if available
        if config.EMBEDDINGS_CLOUD_URL and config.CHUNKS_CLOUD_URL:
            logger.info("Attempting to download embeddings from cloud storage...")
            success = download_embeddings_from_cloud()
            if success:
                return load_resources()  # Try loading again after download
        return False

def download_embeddings_from_cloud():
    """Download embeddings from cloud storage"""
    try:
        import requests
        
        # Download chunks file
        logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
        response = requests.get(config.CHUNKS_CLOUD_URL, timeout=60)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(config.CHUNK_PATH) or ".", exist_ok=True)
            with open(config.CHUNK_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded chunks file")
        else:
            logger.error(f"Failed to download chunks: {response.status_code}")
            return False
            
        # Download embeddings file
        logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
        response = requests.get(config.EMBEDDINGS_CLOUD_URL, timeout=60)
        if response.status_code == 200:
            with open(config.EMBEDDING_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded embeddings file")
            return True
        else:
            logger.error(f"Failed to download embeddings: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"Error downloading embeddings: {e}")
        return False

def is_personal_query(query):
    """Determine if a query is about Swayam or general knowledge"""
    query_lower = query.lower()
    
    # Check if query contains personal keywords
    for keyword in config.PERSONAL_KEYWORDS:
        if keyword.lower() in query_lower:
            logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
            return True
    
    logger.info("Query classified as GENERAL")
    return False

@traceable(run_type="retriever", name="E5 Vector Retriever")
def get_relevant_context(query, top_k=3):
    """Retrieve relevant context from embeddings for a given query"""
    if tokenizer is None or model is None or chunks is None or embeddings is None:
        logger.error("Resources not loaded. Call load_resources() first.")
        return ""
    
    # Process query with e5 model - use "query: " prefix for better retrieval
    inputs = tokenizer(f"query: {query}", padding=True, truncation=True, 
                      return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get query embedding
    query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()
    
    # Calculate similarity with all chunk embeddings
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    
    # Get top k most similar chunks
    top_indices = np.argsort(similarities)[::-1][:top_k]
    
    # Combine the text from the top chunks
    context_parts = []
    for idx in top_indices:
        _, chunk_text = chunks[idx]
        similarity = similarities[idx]
        if similarity > 0.2:  # Only include reasonably similar chunks
            context_parts.append(chunk_text)
            logger.info(f"Including chunk with similarity: {similarity:.4f}")
    
    return "\n\n".join(context_parts)

@traceable(run_type="llm", name="Together AI LLM")
def get_llm_response(messages):
    """Get response from LLM using Together API"""
    if client is None:
        logger.error("Together client not initialized. Cannot get LLM response.")
        return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."
    
    try:
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=messages
        )
        return response.choices[0].message.content
    except AttributeError:
        # Handle older version of together library
        try:
            response = client.completions.create(
                model=config.MODEL_NAME,
                prompt=messages[-1]["content"],
                max_tokens=1000
            )
            return response.choices[0].text
        except Exception as e:
            logger.error(f"Error with fallback API call: {e}")
            return "Sorry, I encountered an error while processing your request."
    except Exception as e:
        logger.error(f"Error calling LLM API: {e}")
        return "Sorry, I encountered an error while processing your request."

@traceable(run_type="chain", name="Response Generator")
def generate_response(query):
    """Generate a response based on the query type"""
    if is_personal_query(query):
        # Personal query - use RAG approach
        context = get_relevant_context(query)
        logger.info(f"Retrieved context: {context[:200]}...")
        
        messages = [
            {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
            {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
        ]
        
        response = get_llm_response(messages)
        return {"response": response, "type": "personal"}
    else:
        # General query - use LLM directly
        messages = [
            {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
            {"role": "user", "content": query}
        ]
        
        response = get_llm_response(messages)
        return {"response": response, "type": "general"}