import math
import warnings
from typing import Dict, List, Union

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import pyttsx3
import speech_recognition as sr
from datasets import load_dataset
from googlesearch import search
from tqdm import tqdm
from transformers import AutoTokenizer
| |
|
| | |
| | warnings.filterwarnings("ignore") |
| |
|
class WebSearchWrapper:
    """Web search helper with a small, case-insensitive LRU result cache.

    Queries are lowercased before being used as cache keys, so lookups
    are case-insensitive. The docstring of the original claimed LRU but
    the implementation was FIFO (hits never refreshed recency); this
    version is a true LRU built on dict insertion order (Python 3.7+).
    """

    def __init__(self, cache_size: int = 100):
        # Insertion-ordered dict doubles as the LRU structure: the most
        # recently used keys live at the end, eviction pops the front.
        self.cache: Dict[str, List[str]] = {}
        self.cache_size = cache_size

    def search(self, query: str, num_results: int = 3) -> List[str]:
        """Return up to ``num_results`` result URLs for ``query``.

        Cached results are returned without touching the network; on any
        search failure an empty list is returned (best-effort semantics,
        matching the original behavior).
        """
        key = query.lower()
        if key in self.cache:
            # Refresh recency so frequently-used queries survive eviction.
            self.cache[key] = self.cache.pop(key)
            return self.cache[key]

        try:
            # NOTE(review): keywords target the maintained
            # `googlesearch-python` package (`num_results=`). The legacy
            # `googlesearch` package used `num`/`stop`/`pause` instead;
            # the original call passed BOTH keyword sets, which raises
            # TypeError on either library and silently returned [] via
            # the except below. Confirm which package is installed.
            search_results = list(search(query, num_results=num_results))
            self._add_to_cache(query, search_results)
            return search_results
        except Exception as e:
            print(f"Web search error: {e}")
            return []

    def _add_to_cache(self, query: str, results: List[str]):
        """Store ``results`` under the lowercased query, evicting the
        least-recently-used entry when the cache is full."""
        if len(self.cache) >= self.cache_size:
            # First key in insertion order == least recently used.
            self.cache.pop(next(iter(self.cache)))
        self.cache[query.lower()] = results
| |
|
class FullChatDataset(Dataset):
    """Concatenation of several HuggingFace conversation datasets.

    Each item is tokenized as a (context, response) pair — all turns but
    the last joined with " [SEP] " form the context, the last turn is the
    response. ``labels`` mirror ``input_ids`` (teacher-forced LM setup).
    """

    def __init__(self, dataset_names=("blended_skill_talk", "conv_ai_2", "social_i_qa"), max_length=256):
        # Tuple default fixes the mutable-default-argument anti-pattern;
        # callers passing a list still work (it is only iterated).
        self.datasets = []
        for name in dataset_names:
            try:
                self.datasets.append(load_dataset(name, split="train"))
            except Exception as e:
                # Best-effort: a dataset that fails to download/parse is skipped.
                print(f"Failed to load dataset {name}: {e}")

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.max_length = max_length

    def __len__(self):
        # Total item count across all successfully loaded datasets.
        return sum(len(d) for d in self.datasets)

    def _locate(self, idx):
        """Map a global index to the item of the owning sub-dataset.

        Raises IndexError for out-of-range indices; the original fell off
        the end of the loop and crashed with UnboundLocalError instead.
        """
        remaining = idx
        for dataset in self.datasets:
            if remaining < len(dataset):
                return dataset[remaining]
            remaining -= len(dataset)
        raise IndexError(f"index {idx} out of range")

    @staticmethod
    def _extract_dialog(item):
        """Pull the ordered list of utterance strings out of a raw row.

        Handles the three schemas seen across the supported datasets:
        'dialog' (list of strings), 'messages' (list of {'text': ...}),
        and a fallback that collects every string-valued field.
        """
        if 'dialog' in item:
            return item['dialog']
        if 'messages' in item:
            return [msg['text'] for msg in item['messages']]
        return [v for v in item.values() if isinstance(v, str)]

    def __getitem__(self, idx):
        item = self._locate(idx)
        dialog = self._extract_dialog(item)
        if not dialog:
            # Malformed row with no usable text: tokenize an empty pair
            # instead of crashing on dialog[-1].
            dialog = [""]

        context = " [SEP] ".join(dialog[:-1])
        response = dialog[-1]

        inputs = self.tokenizer(
            context,
            text_pair=response,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': inputs['input_ids'].flatten()
        }
| |
|
class SimpleTransformerModel(nn.Module):
    """Small encoder-only transformer with an LM head over token ids.

    ``forward`` expects batch-first token ids of shape (batch, seq) and
    returns per-position vocabulary logits of shape (batch, seq, vocab).
    """

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True so the layer consumes (batch, seq, d_model).
        # The original used the seq-first default, silently transposing
        # the batch and sequence axes of the tokenizer's batch-first ids.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Run the encoder over token ids ``x``.

        mask: optional HF-style attention mask of shape (batch, seq) with
        1 = real token, 0 = padding. The original forwarded it as the
        (S, S) additive attention ``mask`` argument — the wrong parameter;
        padding masks belong in ``src_key_padding_mask`` (True = ignore).
        """
        h = self.embedding(x)
        h = self.pos_encoder(h)
        key_padding = None if mask is None else (mask == 0)
        h = self.transformer(h, src_key_padding_mask=key_padding)
        return self.fc(h)
| |
|
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encodings (Vaswani et al., 2017).

    Adds ``pe[:seq_len]`` to a batch-first input of shape
    (batch, seq, d_model) via broadcasting; ``pe`` is registered as a
    non-trainable buffer so it moves with the module across devices.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        positions = torch.arange(max_len).unsqueeze(1)
        # Geometric frequency ladder: one rate per pair of channels.
        freqs = torch.exp((-math.log(10000.0) / d_model) * torch.arange(0, d_model, 2))
        angles = positions * freqs
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        self.register_buffer('pe', pe)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:seq_len]
| |
|
class VoiceInterface:
    """Microphone speech-to-text input and text-to-speech output."""

    def __init__(self):
        self.recognizer = sr.Recognizer()
        self.engine = pyttsx3.init()

    def listen(self) -> Union[str, None]:
        """Capture one utterance from the default microphone.

        Returns the recognized text, or None when recognition fails
        (any recognizer error is reported and swallowed).
        """
        with sr.Microphone() as source:
            print("Listening...")
            audio = self.recognizer.listen(source)
            try:
                heard = self.recognizer.recognize_google(audio)
            except Exception as err:
                print(f"Error recognizing speech: {err}")
                return None
            print(f"You said: {heard}")
            return heard

    def speak(self, text: str):
        """Echo *text* to the console and vocalize it (blocking)."""
        print(f"Bot: {text}")
        self.engine.say(text)
        self.engine.runAndWait()
| |
|
class ChatBot:
    """End-to-end toy chatbot: dataset, model, voice I/O, and web search."""

    def __init__(self):
        self.dataset = FullChatDataset()
        # Vocab is sized from the tokenizer AFTER its pad token is added,
        # so the embedding table covers every id the dataset can emit.
        self.model = SimpleTransformerModel(len(self.dataset.tokenizer))
        self.voice_interface = VoiceInterface()
        self.web_searcher = WebSearchWrapper()

    def train(self, epochs=3, lr=3e-4):
        """Teacher-forced LM training over the combined chat datasets."""
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(device)
        # id 0 is BERT's [PAD]; padded positions contribute no loss.
        criterion = nn.CrossEntropyLoss(ignore_index=0)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        dataloader = DataLoader(self.dataset, batch_size=8, shuffle=True)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

            for batch in pbar:
                inputs = batch['input_ids'].to(device)
                masks = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                optimizer.zero_grad()
                outputs = self.model(inputs, masks)
                loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                pbar.set_postfix({'loss': loss.item()})

            print(f"Epoch {epoch+1} - Avg loss: {total_loss/len(dataloader):.4f}")

    def generate_response(self, prompt: str, max_length: int = 100, use_web: bool = True) -> str:
        """Generate a reply, optionally prefixing live web-search context.

        Fix: ``SimpleTransformerModel`` is a plain nn.Module and defines
        no HF-style ``.generate()`` — the original call always raised
        AttributeError. Replaced with an explicit temperature + top-k
        sampling loop over the model's own logits.
        """
        device = next(self.model.parameters()).device
        self.model.eval()

        if use_web and self._needs_web_search(prompt):
            web_results = self.web_searcher.search(prompt)
            if web_results:
                prompt = f"Web context: {', '.join(web_results[:3])}. User question: {prompt}"

        # No padding here: generation appends to the real tokens only.
        inputs = self.dataset.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=256,
            truncation=True,
        )
        generated = inputs['input_ids'].to(device)
        prompt_len = generated.size(1)
        # [SEP] or [PAD] terminates generation.
        stop_ids = {self.dataset.tokenizer.sep_token_id,
                    self.dataset.tokenizer.pad_token_id}

        with torch.no_grad():
            for _ in range(max_length):
                # Logits for the last position of the (single) sequence.
                logits = self.model(generated)[0, -1] / 0.7  # temperature
                top_vals, top_idx = torch.topk(logits, k=min(50, logits.size(-1)))
                probs = torch.softmax(top_vals, dim=-1)
                next_id = top_idx[torch.multinomial(probs, 1)]
                if next_id.item() in stop_ids:
                    break
                generated = torch.cat([generated, next_id.view(1, 1)], dim=1)
                # Stay inside the positional-encoding table (max_len=500).
                if generated.size(1) >= 500:
                    break

        # Decode only the continuation, not the echoed prompt.
        return self.dataset.tokenizer.decode(generated[0, prompt_len:],
                                             skip_special_tokens=True)

    def _needs_web_search(self, text: str) -> bool:
        """Heuristic: interrogative words or a question mark trigger search.

        Whole-word matching fixes the original substring test, which fired
        on e.g. "show" (contains "how") or "somewhat" (contains "what").
        """
        question_words = {'what', 'when', 'where', 'who', 'why', 'how', 'which'}
        words = set(text.lower().split())
        return '?' in text or bool(words & question_words)