tejaskkkk committed
Commit bc6b8de · verified · Parent: 1e8ab78

Update utils.py

Files changed (1)
  1. utils.py +219 -198
utils.py CHANGED
@@ -1,198 +1,219 @@
- import torch
- import numpy as np
- import pickle
- from transformers import AutoTokenizer, AutoModel
- from sklearn.metrics.pairwise import cosine_similarity
- import logging
- import config
- from together import Together
- import os
- from dotenv import load_dotenv
- from langsmith.run_helpers import traceable
-
- # Load environment variables from .env file
- load_dotenv()
-
- logger = logging.getLogger("swayam-chatbot")
-
- # Initialize Together client with proper error handling
- try:
-     # Try to get API key from environment directly as a fallback
-     api_key = config.TOGETHER_API_KEY or os.environ.get("TOGETHER_API_KEY")
-     if not api_key:
-         logger.warning("No Together API key found. LLM functionality will not work.")
-         client = None
-     else:
-         client = Together(api_key=api_key)
-         logger.info("Together client initialized successfully")
- except Exception as e:
-     logger.error(f"Failed to initialize Together client: {e}")
-     client = None
-
- # Function for mean pooling to get sentence embeddings
- def mean_pooling(model_output, attention_mask):
-     token_embeddings = model_output[0]
-     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
- # Load embeddings and model once at startup
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- tokenizer = None
- model = None
- chunks = None
- embeddings = None
-
- def load_resources():
-     """Load the embedding model and pre-computed embeddings"""
-     global tokenizer, model, chunks, embeddings
-
-     # Load model and tokenizer
-     logger.info("Loading embedding model...")
-     tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
-     model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
-     model.to(device)
-
-     # Create embeddings directory if it doesn't exist
-     os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
-
-     # Load stored chunks and embeddings
-     logger.info("Loading pre-computed embeddings...")
-     try:
-         with open(config.CHUNK_PATH, "rb") as f:
-             chunks = pickle.load(f)
-
-         with open(config.EMBEDDING_PATH, "rb") as f:
-             embeddings = pickle.load(f)
-
-         logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
-         return True
-     except FileNotFoundError as e:
-         logger.error(f"Error loading embeddings: {e}")
-         # Try downloading from cloud storage if available
-         if config.EMBEDDINGS_CLOUD_URL:
-             logger.info(f"Attempting to download embeddings from cloud storage...")
-             success = download_embeddings_from_cloud()
-             if success:
-                 return load_resources()  # Try loading again after download
-         return False
-
- def download_embeddings_from_cloud():
-     """Download embeddings from cloud storage"""
-     try:
-         import requests
-
-         # Download chunks file
-         logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
-         response = requests.get(config.CHUNKS_CLOUD_URL)
-         if response.status_code == 200:
-             os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
-             with open(config.CHUNK_PATH, "wb") as f:
-                 f.write(response.content)
-             logger.info("Successfully downloaded chunks file")
-         else:
-             logger.error(f"Failed to download chunks: {response.status_code}")
-             return False
-
-         # Download embeddings file
-         logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
-         response = requests.get(config.EMBEDDINGS_CLOUD_URL)
-         if response.status_code == 200:
-             with open(config.EMBEDDING_PATH, "wb") as f:
-                 f.write(response.content)
-             logger.info("Successfully downloaded embeddings file")
-             return True
-         else:
-             logger.error(f"Failed to download embeddings: {response.status_code}")
-             return False
-     except Exception as e:
-         logger.error(f"Error downloading embeddings: {e}")
-         return False
-
- def is_personal_query(query):
-     """Determine if a query is about Swayam or general knowledge"""
-     query_lower = query.lower()
-
-     # Check if query contains personal keywords
-     for keyword in config.PERSONAL_KEYWORDS:
-         if keyword.lower() in query_lower:
-             logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
-             return True
-
-     logger.info("Query classified as GENERAL")
-     return False
-
- @traceable(run_type="retriever", name="E5 Vector Retriever")
- def get_relevant_context(query, top_k=3):
-     """Retrieve relevant context from embeddings for a given query"""
-     if tokenizer is None or model is None:
-         logger.error("Embedding model not loaded. Call load_resources() first.")
-         return ""
-
-     # Process query with e5 model - use "query: " prefix for better retrieval
-     inputs = tokenizer(f"query: {query}", padding=True, truncation=True,
-                        return_tensors="pt").to(device)
-
-     with torch.no_grad():
-         outputs = model(**inputs)
-
-     # Get query embedding
-     query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()
-
-     # Calculate similarity with all chunk embeddings
-     similarities = cosine_similarity(query_embedding, embeddings)[0]
-
-     # Get top k most similar chunks
-     top_indices = np.argsort(similarities)[::-1][:top_k]
-
-     # Combine the text from the top chunks
-     context_parts = []
-     for idx in top_indices:
-         _, chunk_text = chunks[idx]
-         similarity = similarities[idx]
-         if similarity > 0.2:  # Only include reasonably similar chunks
-             context_parts.append(chunk_text)
-             logger.info(f"Including chunk with similarity: {similarity:.4f}")
-
-     return "\n\n".join(context_parts)
-
- @traceable(run_type="llm", name="Together AI LLM")
- def get_llm_response(messages):
-     """Get response from LLM using Together API"""
-     if client is None:
-         logger.error("Together client not initialized. Cannot get LLM response.")
-         return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."
-
-     try:
-         response = client.chat.completions.create(
-             model=config.MODEL_NAME,
-             messages=messages
-         )
-         return response.choices[0].message.content
-     except Exception as e:
-         logger.error(f"Error calling LLM API: {e}")
-         return "Sorry, I encountered an error while processing your request."
-
- @traceable(run_type="chain", name="Response Generator")
- def generate_response(query):
-     """Generate a response based on the query type"""
-     if is_personal_query(query):
-         # Personal query - use RAG approach
-         context = get_relevant_context(query)
-         logger.info(f"Retrieved context: {context[:200]}...")
-
-         messages = [
-             {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
-             {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
-         ]
-
-         response = get_llm_response(messages)
-         return {"response": response, "type": "personal"}
-     else:
-         # General query - use LLM directly
-         messages = [
-             {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
-             {"role": "user", "content": query}
-         ]
-
-         response = get_llm_response(messages)
-         return {"response": response, "type": "general"}
+ import torch
+ import numpy as np
+ import pickle
+ from transformers import AutoTokenizer, AutoModel
+ from sklearn.metrics.pairwise import cosine_similarity
+ import logging
+ import config
+ import os
+ from dotenv import load_dotenv
+ from langsmith.run_helpers import traceable
+
+ # Load environment variables from .env file
+ load_dotenv()
+
+ logger = logging.getLogger("swayam-chatbot")
+
+ # Initialize Together client with proper error handling and version compatibility
+ try:
+     # Try different import patterns for different versions of together library
+     try:
+         from together import Together
+     except ImportError:
+         try:
+             from together.client import Together
+         except ImportError:
+             import together
+             Together = together.Together
+
+     # Try to get API key from environment directly as a fallback
+     api_key = config.TOGETHER_API_KEY or os.environ.get("TOGETHER_API_KEY")
+     if not api_key:
+         logger.warning("No Together API key found. LLM functionality will not work.")
+         client = None
+     else:
+         client = Together(api_key=api_key)
+         logger.info("Together client initialized successfully")
+ except Exception as e:
+     logger.error(f"Failed to initialize Together client: {e}")
+     client = None
+
+ # Function for mean pooling to get sentence embeddings
+ def mean_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+ # Load embeddings and model once at startup
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ tokenizer = None
+ model = None
+ chunks = None
+ embeddings = None
+
+ def load_resources():
+     """Load the embedding model and pre-computed embeddings"""
+     global tokenizer, model, chunks, embeddings
+
+     # Load model and tokenizer
+     logger.info("Loading embedding model...")
+     tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
+     model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
+     model.to(device)
+
+     # Create embeddings directory if it doesn't exist
+     os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
+
+     # Load stored chunks and embeddings
+     logger.info("Loading pre-computed embeddings...")
+     try:
+         with open(config.CHUNK_PATH, "rb") as f:
+             chunks = pickle.load(f)
+
+         with open(config.EMBEDDING_PATH, "rb") as f:
+             embeddings = pickle.load(f)
+
+         logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
+         return True
+     except FileNotFoundError as e:
+         logger.error(f"Error loading embeddings: {e}")
+         # Try downloading from cloud storage if available
+         if config.EMBEDDINGS_CLOUD_URL:
+             logger.info(f"Attempting to download embeddings from cloud storage...")
+             success = download_embeddings_from_cloud()
+             if success:
+                 return load_resources()  # Try loading again after download
+         return False
+
+ def download_embeddings_from_cloud():
+     """Download embeddings from cloud storage"""
+     try:
+         import requests
+
+         # Download chunks file
+         logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
+         response = requests.get(config.CHUNKS_CLOUD_URL)
+         if response.status_code == 200:
+             os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
+             with open(config.CHUNK_PATH, "wb") as f:
+                 f.write(response.content)
+             logger.info("Successfully downloaded chunks file")
+         else:
+             logger.error(f"Failed to download chunks: {response.status_code}")
+             return False
+
+         # Download embeddings file
+         logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
+         response = requests.get(config.EMBEDDINGS_CLOUD_URL)
+         if response.status_code == 200:
+             with open(config.EMBEDDING_PATH, "wb") as f:
+                 f.write(response.content)
+             logger.info("Successfully downloaded embeddings file")
+             return True
+         else:
+             logger.error(f"Failed to download embeddings: {response.status_code}")
+             return False
+     except Exception as e:
+         logger.error(f"Error downloading embeddings: {e}")
+         return False
+
+ def is_personal_query(query):
+     """Determine if a query is about Swayam or general knowledge"""
+     query_lower = query.lower()
+
+     # Check if query contains personal keywords
+     for keyword in config.PERSONAL_KEYWORDS:
+         if keyword.lower() in query_lower:
+             logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
+             return True
+
+     logger.info("Query classified as GENERAL")
+     return False
+
+ @traceable(run_type="retriever", name="E5 Vector Retriever")
+ def get_relevant_context(query, top_k=3):
+     """Retrieve relevant context from embeddings for a given query"""
+     if tokenizer is None or model is None:
+         logger.error("Embedding model not loaded. Call load_resources() first.")
+         return ""
+
+     # Process query with e5 model - use "query: " prefix for better retrieval
+     inputs = tokenizer(f"query: {query}", padding=True, truncation=True,
+                        return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Get query embedding
+     query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()
+
+     # Calculate similarity with all chunk embeddings
+     similarities = cosine_similarity(query_embedding, embeddings)[0]
+
+     # Get top k most similar chunks
+     top_indices = np.argsort(similarities)[::-1][:top_k]
+
+     # Combine the text from the top chunks
+     context_parts = []
+     for idx in top_indices:
+         _, chunk_text = chunks[idx]
+         similarity = similarities[idx]
+         if similarity > 0.2:  # Only include reasonably similar chunks
+             context_parts.append(chunk_text)
+             logger.info(f"Including chunk with similarity: {similarity:.4f}")
+
+     return "\n\n".join(context_parts)
+
+ @traceable(run_type="llm", name="Together AI LLM")
+ def get_llm_response(messages):
+     """Get response from LLM using Together API"""
+     if client is None:
+         logger.error("Together client not initialized. Cannot get LLM response.")
+         return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."
+
+     try:
+         response = client.chat.completions.create(
+             model=config.MODEL_NAME,
+             messages=messages
+         )
+         return response.choices[0].message.content
+     except AttributeError:
+         # Handle older version of together library
+         try:
+             response = client.completions.create(
+                 model=config.MODEL_NAME,
+                 prompt=messages[-1]["content"],
+                 max_tokens=1000
+             )
+             return response.choices[0].text
+         except Exception as e:
+             logger.error(f"Error with fallback API call: {e}")
+             return "Sorry, I encountered an error while processing your request."
+     except Exception as e:
+         logger.error(f"Error calling LLM API: {e}")
+         return "Sorry, I encountered an error while processing your request."
+
+ @traceable(run_type="chain", name="Response Generator")
+ def generate_response(query):
+     """Generate a response based on the query type"""
+     if is_personal_query(query):
+         # Personal query - use RAG approach
+         context = get_relevant_context(query)
+         logger.info(f"Retrieved context: {context[:200]}...")
+
+         messages = [
+             {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
+             {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
+         ]
+
+         response = get_llm_response(messages)
+         return {"response": response, "type": "personal"}
+     else:
+         # General query - use LLM directly
+         messages = [
+             {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
+             {"role": "user", "content": query}
+         ]
+
+         response = get_llm_response(messages)
+         return {"response": response, "type": "general"}