import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import logging
import config
import os
from dotenv import load_dotenv
from langsmith.run_helpers import traceable

# Load environment variables from .env file
load_dotenv()

logger = logging.getLogger("swayam-chatbot")

# Initialize Groq client with proper error handling
try:
    from groq import Groq
    
    # Try to get API key from environment directly as a fallback
    api_key = config.GROQ_API_KEY or os.environ.get("GROQ_API_KEY")
    if not api_key:
        logger.warning("No Groq API key found. LLM functionality will not work.")
        client = None
    else:
        client = Groq(api_key=api_key)
        logger.info("Groq client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Groq client: {e}")
    client = None

# Function for mean pooling to get sentence embeddings
def mean_pooling(model_output, attention_mask):
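    # Shape sketch (illustrative, for a typical e5-style encoder with hidden size H):
    #   model_output[0]  -> (batch, seq_len, H)   last hidden state
    #   attention_mask   -> (batch, seq_len)      1 for real tokens, 0 for padding
    #   return value     -> (batch, H)            mask-weighted mean over the sequence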
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

# Load embeddings and model once at startup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = None
model = None
chunks = None
embeddings = None

def load_resources():
    """Load the embedding model and pre-computed embeddings"""
    global tokenizer, model, chunks, embeddings
    
    # Load model and tokenizer
    logger.info("Loading embedding model...")
    tokenizer = AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL)
    model = AutoModel.from_pretrained(config.EMBEDDING_MODEL)
    model.to(device)
    model.eval()  # inference only; ensure dropout is disabled
    
    # Create embeddings directory if it doesn't exist
    os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
    
    # Load stored chunks and embeddings
    logger.info("Loading pre-computed embeddings...")
    try:
        with open(config.CHUNK_PATH, "rb") as f:
            chunks = pickle.load(f)
        
        with open(config.EMBEDDING_PATH, "rb") as f:
            embeddings = pickle.load(f)
            
        logger.info(f"Loaded {len(chunks)} chunks and embeddings of shape {embeddings.shape}")
        return True
    except FileNotFoundError as e:
        logger.error(f"Error loading embeddings: {e}")
        # Try downloading from cloud storage if available
        if config.CHUNKS_CLOUD_URL and config.EMBEDDINGS_CLOUD_URL:
            logger.info("Attempting to download embeddings from cloud storage...")
            success = download_embeddings_from_cloud()
            if success:
                return load_resources()  # Try loading again after download
        return False

def download_embeddings_from_cloud():
    """Download embeddings from cloud storage"""
    try:
        import requests
        
        # Download chunks file
        logger.info(f"Downloading chunks from {config.CHUNKS_CLOUD_URL}")
        response = requests.get(config.CHUNKS_CLOUD_URL, timeout=60)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(config.CHUNK_PATH), exist_ok=True)
            with open(config.CHUNK_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded chunks file")
        else:
            logger.error(f"Failed to download chunks: {response.status_code}")
            return False
            
        # Download embeddings file
        logger.info(f"Downloading embeddings from {config.EMBEDDINGS_CLOUD_URL}")
        response = requests.get(config.EMBEDDINGS_CLOUD_URL, timeout=60)
        if response.status_code == 200:
            with open(config.EMBEDDING_PATH, "wb") as f:
                f.write(response.content)
            logger.info("Successfully downloaded embeddings file")
            return True
        else:
            logger.error(f"Failed to download embeddings: {response.status_code}")
            return False
    except Exception as e:
        logger.error(f"Error downloading embeddings: {e}")
        return False

def is_personal_query(query):
    """Determine if a query is about Swayam or general knowledge"""
    query_lower = query.lower()
    
    # Check if query contains personal keywords
    for keyword in config.PERSONAL_KEYWORDS:
        if keyword.lower() in query_lower:
            logger.info(f"Query classified as PERSONAL due to keyword: {keyword}")
            return True
    
    logger.info("Query classified as GENERAL")
    return False

@traceable(run_type="retriever", name="E5 Vector Retriever")
def get_relevant_context(query, top_k=3):
    """Retrieve relevant context from embeddings for a given query"""
    if tokenizer is None or model is None or chunks is None or embeddings is None:
        logger.error("Embedding model or embeddings not loaded. Call load_resources() first.")
        return ""
    
    # Process query with e5 model - use "query: " prefix for better retrieval
    inputs = tokenizer(f"query: {query}", padding=True, truncation=True, 
                      return_tensors="pt").to(device)
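    # The stored chunk embeddings are assumed to use the matching "passage: " prefix
    # (e5 models are trained with asymmetric "query: "/"passage: " prefixes). A minimal
    # sketch of how such embeddings could be produced (illustrative, not the actual build script):
    #   texts = [f"passage: {text}" for _, text in chunks]
    #   enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    #   with torch.no_grad():
    #       out = model(**enc)
    #   embeddings = mean_pooling(out, enc["attention_mask"]).cpu().numpy()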
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Get query embedding
    query_embedding = mean_pooling(outputs, inputs["attention_mask"]).cpu().numpy()
    
    # Calculate similarity with all chunk embeddings
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    
    # Get top k most similar chunks
    top_indices = np.argsort(similarities)[::-1][:top_k]
    
    # Combine the text from the top chunks
    context_parts = []
    for idx in top_indices:
        _, chunk_text = chunks[idx]
        similarity = similarities[idx]
        if similarity > 0.2:  # Only include reasonably similar chunks
            context_parts.append(chunk_text)
            logger.info(f"Including chunk with similarity: {similarity:.4f}")
    
    return "\n\n".join(context_parts)

@traceable(run_type="llm", name="Groq LLM")
def get_llm_response(messages):
    """Get response from LLM using Groq API"""
    if client is None:
        logger.error("Groq client not initialized. Cannot get LLM response.")
        return "Sorry, I cannot access the language model at the moment. Please ensure the API key is set correctly."
    
    try:
        response = client.chat.completions.create(
            model=config.MODEL_NAME,
            messages=messages,
            temperature=0.7,
            max_completion_tokens=1024,
            top_p=1,
            stream=False
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error calling Groq LLM API: {e}")
        return "Sorry, I encountered an error while processing your request."

@traceable(run_type="chain", name="Response Generator")
def generate_response(query):
    """Generate a response based on the query type"""
    if is_personal_query(query):
        # Personal query - use RAG approach
        context = get_relevant_context(query)
        logger.info(f"Retrieved context: {context[:200]}...")
        
        messages = [
            {"role": "system", "content": config.PERSONAL_SYSTEM_PROMPT},
            {"role": "user", "content": f"Context about Swayam:\n{context}\n\nQuestion: {query}"}
        ]
        
        response = get_llm_response(messages)
        return {"response": response, "type": "personal"}
    else:
        # General query - use LLM directly
        messages = [
            {"role": "system", "content": config.GENERAL_SYSTEM_PROMPT},
            {"role": "user", "content": query}
        ]
        
        response = get_llm_response(messages)
        return {"response": response, "type": "general"}