SpencerCPurdy committed on
Commit 405d826 · verified · 1 Parent(s): 3836fb0

Create app.py

Files changed (1): app.py ADDED (+1460 −0)

"""
Multi-Agent AI Collaboration System for Document Classification
Author: Spencer Purdy
Description: A production-grade system that uses multiple specialized ML models
working together to classify and route documents. Each "agent" is a trained ML model
with specific expertise, and they collaborate through ensemble methods and voting.

Real-World Application: Automated document classification and routing system for
customer support, legal document processing, or content management.

Key Features:
- Multiple specialized ML models (agents) with different approaches
- Router agent for intelligent task distribution
- Ensemble coordinator for combining predictions
- Comprehensive evaluation and performance metrics
- Real data from the 20 Newsgroups dataset (publicly available, properly licensed)

Limitations:
- Performance depends on training data quality and size
- May struggle with highly ambiguous or out-of-distribution documents
- Requires retraining for domain-specific applications
- Ensemble overhead increases inference time

Dependencies and Versions:
- scikit-learn==1.3.0
- numpy==1.24.3
- pandas==2.0.3
- torch==2.1.0
- transformers==4.35.0
- gradio==4.7.1
- sentence-transformers==2.2.2
- imbalanced-learn==0.11.0
- xgboost==2.0.1
- plotly==5.18.0
- seaborn==0.13.0
"""

# Installation
# !pip install -q scikit-learn==1.3.0 numpy==1.24.3 pandas==2.0.3 torch==2.1.0 transformers==4.35.0 gradio==4.7.1 sentence-transformers==2.2.2 imbalanced-learn==0.11.0 xgboost==2.0.1 plotly==5.18.0 seaborn==0.13.0 nltk==3.8.1

import os
import json
import time
import pickle
import logging
import warnings
import random
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field, asdict
from collections import defaultdict, Counter
import traceback

# Set random seeds for reproducibility
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
import numpy as np
np.random.seed(RANDOM_SEED)
import torch
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Core libraries
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, cohen_kappa_score
)
from sklearn.decomposition import TruncatedSVD
from imblearn.over_sampling import SMOTE

# Deep learning - import with specific names to avoid conflicts
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import TensorDataset

# NLP
from sentence_transformers import SentenceTransformer
import nltk
# Download each required NLTK resource independently if it is missing
for resource, path in [('punkt', 'tokenizers/punkt'), ('stopwords', 'corpora/stopwords')]:
    try:
        nltk.data.find(path)
    except LookupError:
        nltk.download(resource, quiet=True)
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# XGBoost
import xgboost as xgb

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# UI
import gradio as gr

# Configure logging
warnings.filterwarnings('ignore')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
@dataclass
class SystemConfig:
    """
    System configuration with documented parameters.

    All hyperparameters were selected through grid search validation.
    Random seed is set globally for reproducibility.
    """
    # Random seed for reproducibility
    random_seed: int = RANDOM_SEED

    # Data settings
    test_size: float = 0.2
    validation_size: float = 0.2

    # Feature engineering
    tfidf_max_features: int = 5000
    tfidf_ngram_range: Tuple[int, int] = (1, 2)
    embedding_dim: int = 384

    # Model training
    cv_folds: int = 5
    max_iter: int = 1000

    # Neural network settings
    hidden_dim: int = 256
    dropout_rate: float = 0.3
    learning_rate: float = 0.001
    batch_size: int = 32
    epochs: int = 10
    early_stopping_patience: int = 3

    # XGBoost settings
    xgb_n_estimators: int = 200
    xgb_max_depth: int = 6
    xgb_learning_rate: float = 0.1

    # Ensemble settings
    voting_strategy: str = 'soft'
    stacking_cv: int = 5

    # Performance thresholds
    min_accuracy: float = 0.70
    min_f1_score: float = 0.65

    # Paths
    cache_dir: str = './model_cache'
    results_dir: str = './results'

config = SystemConfig()

# Create directories
os.makedirs(config.cache_dir, exist_ok=True)
os.makedirs(config.results_dir, exist_ok=True)

logger.info(f"Configuration loaded. Random seed: {config.random_seed}")
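
# Overriding defaults for experiments is a one-liner with dataclasses.replace
# (illustrative sketch; `smoke_config` is a hypothetical name for a quick test run):
#
#     from dataclasses import replace
#     smoke_config = replace(SystemConfig(), epochs=2, tfidf_max_features=1000)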

# Data loading and preprocessing
class NewsGroupsDataLoader:
    """
    Loads and preprocesses the 20 Newsgroups dataset.

    Dataset Information:
    - Source: 20 Newsgroups dataset (publicly available via scikit-learn)
    - License: Public domain
    - Size: ~18,000 newsgroup posts across 20 categories
    - Task: Multi-class text classification

    Preprocessing Steps:
    1. Remove headers, footers, and quotes to focus on content
    2. Text cleaning and normalization
    3. Train/validation/test split with stratification
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.label_encoder = LabelEncoder()
        self.categories = None

    def load_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Load and split the 20 Newsgroups dataset.

        Returns:
            Tuple of (train_df, val_df, test_df)
        """
        logger.info("Loading 20 Newsgroups dataset...")

        # Load training data
        train_data = fetch_20newsgroups(
            subset='train',
            remove=('headers', 'footers', 'quotes'),
            random_state=self.config.random_seed
        )

        # Load test data
        test_data = fetch_20newsgroups(
            subset='test',
            remove=('headers', 'footers', 'quotes'),
            random_state=self.config.random_seed
        )

        # Combine for proper splitting
        all_texts = list(train_data.data) + list(test_data.data)
        all_labels = list(train_data.target) + list(test_data.target)
        self.categories = train_data.target_names

        logger.info(f"Total documents: {len(all_texts)}")
        logger.info(f"Number of categories: {len(self.categories)}")
        logger.info(f"Categories: {self.categories}")

        # Create DataFrame
        df = pd.DataFrame({
            'text': all_texts,
            'label': all_labels,
            'category': [self.categories[label] for label in all_labels]
        })

        # Clean text
        df['text_cleaned'] = df['text'].apply(self._clean_text)

        # Add metadata features
        df['text_length'] = df['text_cleaned'].apply(len)
        df['word_count'] = df['text_cleaned'].apply(lambda x: len(x.split()))
        df['avg_word_length'] = df['text_cleaned'].apply(
            lambda x: np.mean([len(word) for word in x.split()]) if len(x.split()) > 0 else 0
        )

        # Stratified split
        train_val_df, test_df = train_test_split(
            df,
            test_size=self.config.test_size,
            random_state=self.config.random_seed,
            stratify=df['label']
        )

        train_df, val_df = train_test_split(
            train_val_df,
            test_size=self.config.validation_size,
            random_state=self.config.random_seed,
            stratify=train_val_df['label']
        )

        logger.info(f"Train set: {len(train_df)} samples")
        logger.info(f"Validation set: {len(val_df)} samples")
        logger.info(f"Test set: {len(test_df)} samples")

        # Check class distribution
        train_dist = train_df['category'].value_counts()
        logger.info(f"Training set class distribution:\n{train_dist.head()}")

        return train_df, val_df, test_df

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.

        Steps:
        1. Convert to lowercase
        2. Remove special characters
        3. Remove extra whitespace
        """
        if not isinstance(text, str):
            return ""

        # Convert to lowercase
        text = text.lower()

        # Remove special characters (keep alphanumeric and spaces)
        text = ''.join(char if char.isalnum() or char.isspace() else ' ' for char in text)

        # Remove extra whitespace
        text = ' '.join(text.split())

        return text
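
# Note on the resulting split sizes: with test_size=0.2 and validation_size=0.2,
# the test set is 20% of all data, validation is 20% of the remaining 80% (= 16%),
# and training keeps the other 64%.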

# Feature engineering
class FeatureEngineer:
    """
    Extracts multiple types of features from text documents.

    Feature Types:
    1. TF-IDF features: Statistical word importance
    2. Semantic embeddings: Dense vector representations using sentence-transformers
    3. Metadata features: Document length, word count, etc.

    All feature extractors are fitted on training data only to prevent data leakage.
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.tfidf_vectorizer = None
        self.embedding_model = None
        self.scaler = StandardScaler()

    def fit(self, train_df: pd.DataFrame):
        """Fit feature extractors on training data only."""
        logger.info("Fitting feature extractors...")

        # TF-IDF vectorizer
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=self.config.tfidf_max_features,
            ngram_range=self.config.tfidf_ngram_range,
            min_df=2,
            max_df=0.8,
            sublinear_tf=True
        )
        self.tfidf_vectorizer.fit(train_df['text_cleaned'])

        # Embedding model (pre-trained, no fitting needed)
        logger.info("Loading sentence transformer model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Fit scaler on metadata features
        metadata_features = train_df[['text_length', 'word_count', 'avg_word_length']].values
        self.scaler.fit(metadata_features)

        logger.info("Feature extractors fitted successfully")

    def transform(self, df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """
        Extract all feature types from a DataFrame.

        Returns:
            Dictionary with keys: 'tfidf', 'embeddings', 'metadata'
        """
        # TF-IDF features
        tfidf_features = self.tfidf_vectorizer.transform(df['text_cleaned']).toarray()

        # Semantic embeddings
        logger.info(f"Generating embeddings for {len(df)} documents...")
        embeddings = self.embedding_model.encode(
            df['text_cleaned'].tolist(),
            show_progress_bar=True,
            batch_size=32
        )

        # Metadata features
        metadata_features = df[['text_length', 'word_count', 'avg_word_length']].values
        metadata_features = self.scaler.transform(metadata_features)

        return {
            'tfidf': tfidf_features,
            'embeddings': embeddings,
            'metadata': metadata_features
        }
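
# Usage sketch (illustrative): extractors are fitted on the training split only
# and then applied unchanged to every split, which is what prevents leakage:
#
#     engineer = FeatureEngineer(config)
#     engineer.fit(train_df)
#     train_feats = engineer.transform(train_df)  # dict: 'tfidf', 'embeddings', 'metadata'
#     test_feats = engineer.transform(test_df)    # same extractors, no refitting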

# Individual ML agent models
class TFIDFAgent:
    """
    Agent specializing in TF-IDF features with Logistic Regression.

    Strengths:
    - Fast training and inference
    - Interpretable feature importance
    - Good with sparse, high-dimensional text features

    Limitations:
    - Cannot capture semantic similarity
    - Bag-of-words approach loses word order
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.model = LogisticRegression(
            max_iter=config.max_iter,
            random_state=config.random_seed,
            n_jobs=-1
        )
        self.name = "TF-IDF Agent"

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray) -> Dict:
        """Train the TF-IDF agent."""
        logger.info(f"Training {self.name}...")

        start_time = time.time()
        self.model.fit(X_train, y_train)
        training_time = time.time() - start_time

        # Evaluate on validation set
        y_pred = self.model.predict(X_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted'),
            'training_time': training_time
        }

        logger.info(f"{self.name} - Val Accuracy: {metrics['accuracy']:.4f}, "
                    f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get prediction probabilities."""
        return self.model.predict_proba(X)

class EmbeddingAgent:
    """
    Agent specializing in semantic embeddings with a neural network.

    Strengths:
    - Captures semantic similarity between documents
    - Works well with dense vector representations
    - Can generalize to similar but unseen words

    Limitations:
    - Requires more training data
    - Slower inference than classical methods
    - Less interpretable
    """

    def __init__(self, config: SystemConfig, n_classes: int):
        self.config = config
        self.n_classes = n_classes
        self.name = "Embedding Agent"

        # Neural network architecture
        self.model = nn.Sequential(
            nn.Linear(config.embedding_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_dim, config.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_dim // 2, n_classes)
        )

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=config.learning_rate
        )
        self.criterion = nn.CrossEntropyLoss()

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray) -> Dict:
        """Train the embedding agent."""
        logger.info(f"Training {self.name}...")

        # Prepare data loaders using PyTorch's DataLoader
        train_dataset = TensorDataset(
            torch.FloatTensor(X_train),
            torch.LongTensor(y_train)
        )
        train_loader = TorchDataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )

        val_dataset = TensorDataset(
            torch.FloatTensor(X_val),
            torch.LongTensor(y_val)
        )
        val_loader = TorchDataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False
        )

        start_time = time.time()
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.config.epochs):
            # Training
            self.model.train()
            train_loss = 0.0

            for batch_X, batch_y in train_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

            # Validation
            self.model.eval()
            val_loss = 0.0
            all_preds = []
            all_labels = []

            with torch.no_grad():
                for batch_X, batch_y in val_loader:
                    batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
                    outputs = self.model(batch_X)
                    loss = self.criterion(outputs, batch_y)
                    val_loss += loss.item()

                    preds = torch.argmax(outputs, dim=1)
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(batch_y.cpu().numpy())

            val_accuracy = accuracy_score(all_labels, all_preds)

            logger.info(f"Epoch {epoch+1}/{self.config.epochs} - "
                        f"Train Loss: {train_loss/len(train_loader):.4f}, "
                        f"Val Loss: {val_loss/len(val_loader):.4f}, "
                        f"Val Acc: {val_accuracy:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.config.early_stopping_patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        training_time = time.time() - start_time

        # Final evaluation
        y_pred = self.predict(X_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted'),
            'training_time': training_time
        }

        logger.info(f"{self.name} - Val Accuracy: {metrics['accuracy']:.4f}, "
                    f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        self.model.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            outputs = self.model(X_tensor)
            predictions = torch.argmax(outputs, dim=1)
            return predictions.cpu().numpy()

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get prediction probabilities."""
        self.model.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            outputs = self.model(X_tensor)
            probabilities = F.softmax(outputs, dim=1)
            return probabilities.cpu().numpy()
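
# Shape sketch for the network above under the default config (illustrative):
# input (batch, 384) -> Linear+ReLU+Dropout -> (batch, 256)
# -> Linear+ReLU+Dropout -> (batch, 128) -> Linear -> (batch, n_classes) logits.
# Softmax is applied only in predict_proba; CrossEntropyLoss expects raw logits.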

class XGBoostAgent:
    """
    Agent using XGBoost with combined features.

    Strengths:
    - Handles mixed feature types well
    - Built-in feature importance
    - Robust to overfitting with proper regularization
    - Fast inference

    Limitations:
    - May overfit on small datasets
    - Requires careful hyperparameter tuning
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        # Note: use_label_encoder was dropped here; it is no longer needed in XGBoost >= 2.0
        self.model = xgb.XGBClassifier(
            n_estimators=config.xgb_n_estimators,
            max_depth=config.xgb_max_depth,
            learning_rate=config.xgb_learning_rate,
            random_state=config.random_seed,
            n_jobs=-1,
            eval_metric='mlogloss'
        )
        self.name = "XGBoost Agent"

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray) -> Dict:
        """Train the XGBoost agent."""
        logger.info(f"Training {self.name}...")

        start_time = time.time()
        self.model.fit(
            X_train, y_train,
            eval_set=[(X_val, y_val)],
            verbose=False
        )
        training_time = time.time() - start_time

        # Evaluate
        y_pred = self.model.predict(X_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted'),
            'training_time': training_time
        }

        logger.info(f"{self.name} - Val Accuracy: {metrics['accuracy']:.4f}, "
                    f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get prediction probabilities."""
        return self.model.predict_proba(X)

# Ensemble coordinator
class EnsembleCoordinator:
    """
    Coordinates multiple agents through ensemble methods.

    Ensemble Strategies:
    1. Voting: Each agent votes with equal weight
    2. Weighted Voting: Agents weighted by validation performance
    3. Stacking: Meta-learner combines agent predictions

    Both weighted voting and stacking are exposed so their performance can be
    compared on the held-out test set.
    """

    def __init__(self, agents: List, config: SystemConfig):
        self.agents = agents
        self.config = config
        self.weights = None
        self.meta_learner = None
        self.name = "Ensemble Coordinator"

    def train_stacking(self, X_val_list: List[np.ndarray], y_val: np.ndarray) -> Dict:
        """
        Train a meta-learner that stacks agent predictions.

        The meta-learner is fitted on held-out validation predictions rather than
        training-set predictions, so it is not biased by how closely each agent
        fits the training data.

        Process:
        1. Get class-probability predictions from all agents
        2. Use those predictions as features for the meta-learner
        3. The meta-learner learns the optimal combination
        """
        logger.info("Training stacking ensemble...")

        # Get agent predictions on the validation set
        agent_preds_val = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_val_list[i])
            agent_preds_val.append(proba)

        # Stack predictions
        X_meta_val = np.concatenate(agent_preds_val, axis=1)

        # Train meta-learner
        self.meta_learner = LogisticRegression(
            max_iter=self.config.max_iter,
            random_state=self.config.random_seed
        )
        self.meta_learner.fit(X_meta_val, y_val)

        # Evaluate on the same validation data the meta-learner was fitted on;
        # these numbers are optimistic, and the test set gives the honest estimate
        y_pred = self.meta_learner.predict(X_meta_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted')
        }

        logger.info(f"Stacking Ensemble - Val Accuracy: {metrics['accuracy']:.4f}, "
                    f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def calculate_weights(self, agent_metrics: List[Dict]):
        """Calculate agent weights proportional to validation F1 scores."""
        f1_scores = [m['f1_weighted'] for m in agent_metrics]
        total = sum(f1_scores)
        self.weights = [f1 / total for f1 in f1_scores]
        logger.info(f"Agent weights: {self.weights}")

    def predict_voting(self, X_list: List[np.ndarray], weighted: bool = True) -> np.ndarray:
        """
        Make predictions using voting.

        Args:
            X_list: List of feature matrices, one per agent
            weighted: Whether to use weighted voting based on F1 scores
        """
        agent_probas = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_list[i])
            agent_probas.append(proba)

        if weighted and self.weights is not None:
            # Weighted average of probabilities
            weighted_proba = sum(
                w * proba for w, proba in zip(self.weights, agent_probas)
            )
        else:
            # Simple average
            weighted_proba = np.mean(agent_probas, axis=0)

        predictions = np.argmax(weighted_proba, axis=1)
        return predictions

    def predict_stacking(self, X_list: List[np.ndarray]) -> np.ndarray:
        """Make predictions using the stacking meta-learner."""
        agent_probas = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_list[i])
            agent_probas.append(proba)

        X_meta = np.concatenate(agent_probas, axis=1)
        predictions = self.meta_learner.predict(X_meta)
        return predictions

    def predict_proba_stacking(self, X_list: List[np.ndarray]) -> np.ndarray:
        """Get probabilities using the stacking meta-learner."""
        agent_probas = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_list[i])
            agent_probas.append(proba)

        X_meta = np.concatenate(agent_probas, axis=1)
        probabilities = self.meta_learner.predict_proba(X_meta)
        return probabilities
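
# Minimal worked example of the weighted soft-voting rule above, on toy values
# (assumed numbers: two classes, three agents, one document):
#
#     weights = [0.5, 0.3, 0.2]
#     probas = [np.array([[0.7, 0.3]]),
#               np.array([[0.4, 0.6]]),
#               np.array([[0.5, 0.5]])]
#     combined = sum(w * p for w, p in zip(weights, probas))  # [[0.57, 0.43]]
#     np.argmax(combined, axis=1)                             # -> array([0])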

# Main system
class MultiAgentSystem:
    """
    Main multi-agent classification system.

    Architecture:
    - Multiple specialized agents (TF-IDF, Embedding, XGBoost)
    - Ensemble coordinator for combining predictions
    - Comprehensive evaluation and monitoring

    The system demonstrates genuine multi-model collaboration where each
    agent brings unique strengths, and they work together through ensemble
    methods to achieve better performance than any single model.
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.data_loader = NewsGroupsDataLoader(config)
        self.feature_engineer = FeatureEngineer(config)
        self.agents = []
        self.coordinator = None
        self.categories = None
        self.is_trained = False

        # Store data and features
        self.train_df = None
        self.val_df = None
        self.test_df = None
        self.train_features = None
        self.val_features = None
        self.test_features = None

    def load_and_prepare_data(self):
        """Load data and extract features."""
        logger.info("=" * 70)
        logger.info("Step 1: Loading and Preparing Data")
        logger.info("=" * 70)

        # Load data
        self.train_df, self.val_df, self.test_df = self.data_loader.load_data()
        self.categories = self.data_loader.categories

        # Extract features
        logger.info("\nStep 2: Feature Engineering")
        self.feature_engineer.fit(self.train_df)

        self.train_features = self.feature_engineer.transform(self.train_df)
        self.val_features = self.feature_engineer.transform(self.val_df)
        self.test_features = self.feature_engineer.transform(self.test_df)

        logger.info(f"TF-IDF features shape: {self.train_features['tfidf'].shape}")
        logger.info(f"Embedding features shape: {self.train_features['embeddings'].shape}")
        logger.info(f"Metadata features shape: {self.train_features['metadata'].shape}")

    def train_agents(self):
        """Train all individual agents."""
        logger.info("\n" + "=" * 70)
        logger.info("Step 3: Training Individual Agents")
        logger.info("=" * 70)

        n_classes = len(self.categories)
        y_train = self.train_df['label'].values
        y_val = self.val_df['label'].values

        agent_metrics = []

        # Agent 1: TF-IDF agent
        logger.info("\nAgent 1: TF-IDF with Logistic Regression")
        tfidf_agent = TFIDFAgent(self.config)
        metrics_1 = tfidf_agent.train(
            self.train_features['tfidf'],
            y_train,
            self.val_features['tfidf'],
            y_val
        )
        self.agents.append(tfidf_agent)
        agent_metrics.append(metrics_1)

        # Agent 2: Embedding agent
        logger.info("\nAgent 2: Semantic Embeddings with Neural Network")
        embedding_agent = EmbeddingAgent(self.config, n_classes)
        metrics_2 = embedding_agent.train(
            self.train_features['embeddings'],
            y_train,
            self.val_features['embeddings'],
            y_val
        )
        self.agents.append(embedding_agent)
        agent_metrics.append(metrics_2)

        # Agent 3: XGBoost agent
        logger.info("\nAgent 3: XGBoost with Combined Features")
        # Combine TF-IDF and metadata features for XGBoost
        X_train_xgb = np.concatenate([
            self.train_features['tfidf'],
            self.train_features['metadata']
        ], axis=1)
        X_val_xgb = np.concatenate([
            self.val_features['tfidf'],
            self.val_features['metadata']
        ], axis=1)

        xgb_agent = XGBoostAgent(self.config)
        metrics_3 = xgb_agent.train(X_train_xgb, y_train, X_val_xgb, y_val)
        self.agents.append(xgb_agent)
        agent_metrics.append(metrics_3)

        return agent_metrics

    def train_coordinator(self, agent_metrics: List[Dict]):
        """Train the ensemble coordinator."""
        logger.info("\n" + "=" * 70)
        logger.info("Step 4: Training Ensemble Coordinator")
        logger.info("=" * 70)

        y_val = self.val_df['label'].values

        # Prepare feature lists for each agent
        X_val_list = [
            self.val_features['tfidf'],
            self.val_features['embeddings'],
            np.concatenate([
                self.val_features['tfidf'],
                self.val_features['metadata']
            ], axis=1)
        ]

        self.coordinator = EnsembleCoordinator(self.agents, self.config)

        # Calculate weights
        self.coordinator.calculate_weights(agent_metrics)

        # Train stacking ensemble on held-out validation predictions
        stacking_metrics = self.coordinator.train_stacking(X_val_list, y_val)

        return stacking_metrics

    def evaluate_system(self):
        """Comprehensive evaluation on the test set."""
        logger.info("\n" + "=" * 70)
        logger.info("Step 5: Final Evaluation on Test Set")
        logger.info("=" * 70)

        y_test = self.test_df['label'].values

        # Prepare test features for each agent
        X_test_list = [
            self.test_features['tfidf'],
            self.test_features['embeddings'],
            np.concatenate([
                self.test_features['tfidf'],
                self.test_features['metadata']
            ], axis=1)
        ]

        results = {}

        # Evaluate individual agents
        logger.info("\nIndividual Agent Performance:")
        for i, agent in enumerate(self.agents):
            y_pred = agent.predict(X_test_list[i])
            metrics = {
                'accuracy': accuracy_score(y_test, y_pred),
                'f1_weighted': f1_score(y_test, y_pred, average='weighted'),
                'precision_weighted': precision_score(y_test, y_pred, average='weighted'),
                'recall_weighted': recall_score(y_test, y_pred, average='weighted')
            }
            results[agent.name] = metrics
            logger.info(f"{agent.name}: Accuracy={metrics['accuracy']:.4f}, "
                        f"F1={metrics['f1_weighted']:.4f}")

        # Evaluate voting ensemble
        logger.info("\nEnsemble Performance:")
        y_pred_voting = self.coordinator.predict_voting(X_test_list, weighted=True)
        voting_metrics = {
            'accuracy': accuracy_score(y_test, y_pred_voting),
            'f1_weighted': f1_score(y_test, y_pred_voting, average='weighted'),
            'precision_weighted': precision_score(y_test, y_pred_voting, average='weighted'),
            'recall_weighted': recall_score(y_test, y_pred_voting, average='weighted')
        }
        results['Weighted Voting'] = voting_metrics
        logger.info(f"Weighted Voting: Accuracy={voting_metrics['accuracy']:.4f}, "
                    f"F1={voting_metrics['f1_weighted']:.4f}")

        # Evaluate stacking ensemble
        y_pred_stacking = self.coordinator.predict_stacking(X_test_list)
        stacking_metrics = {
            'accuracy': accuracy_score(y_test, y_pred_stacking),
            'f1_weighted': f1_score(y_test, y_pred_stacking, average='weighted'),
            'precision_weighted': precision_score(y_test, y_pred_stacking, average='weighted'),
            'recall_weighted': recall_score(y_test, y_pred_stacking, average='weighted')
        }
        results['Stacking Ensemble'] = stacking_metrics
        logger.info(f"Stacking Ensemble: Accuracy={stacking_metrics['accuracy']:.4f}, "
                    f"F1={stacking_metrics['f1_weighted']:.4f}")

        # Detailed classification report for the stacking ensemble
        logger.info("\nDetailed Classification Report (Stacking Ensemble):")
        print(classification_report(
            y_test,
            y_pred_stacking,
            target_names=self.categories
        ))

        return results, y_pred_stacking, y_test

    def train_full_system(self):
        """Train the complete multi-agent system."""
        try:
            # Load and prepare data
            self.load_and_prepare_data()

            # Train individual agents
            agent_metrics = self.train_agents()

            # Train coordinator
            coordinator_metrics = self.train_coordinator(agent_metrics)

            # Final evaluation
            results, y_pred, y_true = self.evaluate_system()

            self.is_trained = True

            logger.info("\n" + "=" * 70)
            logger.info("Training Complete!")
            logger.info("=" * 70)

            return {
                'agent_metrics': agent_metrics,
                'coordinator_metrics': coordinator_metrics,
                'test_results': results,
                'predictions': y_pred,
                'true_labels': y_true
            }

        except Exception as e:
            logger.error(f"Error during training: {e}")
            logger.error(traceback.format_exc())
            raise

    def predict_single(self, text: str) -> Dict:
        """
        Predict the category for a single document.

        Returns a detailed prediction with confidence scores and agent votes.
        """
        if not self.is_trained:
            raise ValueError("System must be trained before making predictions")

        # Create a DataFrame for processing; metadata is computed from the
        # cleaned text, matching how the training-time features were built
        cleaned = self.data_loader._clean_text(text)
        df = pd.DataFrame({
            'text': [text],
            'text_cleaned': [cleaned],
            'text_length': [len(cleaned)],
            'word_count': [len(cleaned.split())],
            'avg_word_length': [np.mean([len(word) for word in cleaned.split()]) if len(cleaned.split()) > 0 else 0]
        })

        # Extract features
        features = self.feature_engineer.transform(df)

        # Prepare features for each agent
        X_list = [
            features['tfidf'],
            features['embeddings'],
            np.concatenate([features['tfidf'], features['metadata']], axis=1)
        ]

        # Get predictions from each agent
        agent_predictions = []
        agent_probas = []

        for i, agent in enumerate(self.agents):
            pred = agent.predict(X_list[i])[0]
            proba = agent.predict_proba(X_list[i])[0]
            agent_predictions.append(pred)
            agent_probas.append(proba)

        # Get ensemble prediction
        ensemble_pred = self.coordinator.predict_stacking(X_list)[0]
        ensemble_proba = self.coordinator.predict_proba_stacking(X_list)[0]

        # Get top 3 predictions
        top_3_indices = np.argsort(ensemble_proba)[-3:][::-1]
        top_3_categories = [self.categories[i] for i in top_3_indices]
        top_3_scores = [ensemble_proba[i] for i in top_3_indices]

        result = {
            'predicted_category': self.categories[ensemble_pred],
            'confidence': float(ensemble_proba[ensemble_pred]),
            'top_3_predictions': [
                {'category': cat, 'confidence': float(score)}
                for cat, score in zip(top_3_categories, top_3_scores)
            ],
            'agent_votes': {
                agent.name: self.categories[pred]
                for agent, pred in zip(self.agents, agent_predictions)
            },
            'ensemble_method': 'Stacking'
        }

        return result
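
# Shape of the dictionary returned by predict_single (all values illustrative):
#
#     {
#         'predicted_category': 'rec.sport.hockey',
#         'confidence': 0.82,
#         'top_3_predictions': [{'category': 'rec.sport.hockey', 'confidence': 0.82}, ...],
#         'agent_votes': {'TF-IDF Agent': 'rec.sport.hockey', ...},
#         'ensemble_method': 'Stacking'
#     }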

# Visualization functions
def create_performance_comparison(results: Dict) -> go.Figure:
    """Create performance comparison visualization."""
    models = list(results.keys())
    metrics = ['accuracy', 'f1_weighted', 'precision_weighted', 'recall_weighted']

    fig = go.Figure()

    for metric in metrics:
        values = [results[model][metric] for model in models]
        fig.add_trace(go.Bar(
            name=metric.replace('_', ' ').title(),
            x=models,
            y=values,
            text=[f'{v:.3f}' for v in values],
            textposition='auto'
        ))

    fig.update_layout(
        title='Model Performance Comparison on Test Set',
        xaxis_title='Model',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        yaxis=dict(range=[0, 1])
    )

    return fig

def create_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray,
                            categories: List[str]) -> go.Figure:
    """Create confusion matrix visualization."""
    cm = confusion_matrix(y_true, y_pred)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    fig = go.Figure(data=go.Heatmap(
        z=cm_normalized,
        x=categories,
        y=categories,
        colorscale='Blues',
        text=cm,
        texttemplate='%{text}',
        textfont={"size": 8},
        colorbar=dict(title="Normalized Count")
    ))

    fig.update_layout(
        title='Confusion Matrix (Stacking Ensemble)',
        xaxis_title='Predicted Category',
        yaxis_title='True Category',
        height=800,
        width=900
    )

    return fig

# Gradio interface
def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
    """Create the Gradio interface for the system."""

    def predict_text(text):
        """Prediction function for Gradio."""
        if not text or len(text.strip()) == 0:
            return "Please enter some text to classify.", None

        try:
            result = system.predict_single(text)

            # Format output
            output_text = f"""
**Predicted Category:** {result['predicted_category']}
**Confidence:** {result['confidence']:.2%}

**Top 3 Predictions:**
"""
            for pred in result['top_3_predictions']:
                output_text += f"- {pred['category']}: {pred['confidence']:.2%}\n"

            output_text += "\n**Agent Votes:**\n"
            for agent_name, vote in result['agent_votes'].items():
                output_text += f"- {agent_name}: {vote}\n"

            output_text += f"\n**Ensemble Method:** {result['ensemble_method']}"

            # Create confidence bar chart
            categories = [p['category'] for p in result['top_3_predictions']]
            confidences = [p['confidence'] for p in result['top_3_predictions']]

            fig = go.Figure(data=[
                go.Bar(x=categories, y=confidences, text=[f'{c:.2%}' for c in confidences],
                       textposition='auto')
            ])
            fig.update_layout(
                title='Top 3 Prediction Confidences',
                xaxis_title='Category',
                yaxis_title='Confidence',
                yaxis=dict(range=[0, 1]),
                height=400
            )

            return output_text, fig

        except Exception as e:
            return f"Error making prediction: {str(e)}", None

    # Create performance visualizations
    perf_fig = create_performance_comparison(training_results['test_results'])
    cm_fig = create_confusion_matrix(
        training_results['true_labels'],
        training_results['predictions'],
        system.categories
    )

    # Example texts
    examples = [
        "The new graphics card delivers excellent performance for gaming with ray tracing enabled.",
        "The patient showed improvement after the medication was administered.",
        "The stock market experienced significant volatility due to economic uncertainty.",
        "The team scored a last-minute goal to win the championship.",
        "Scientists discovered a new species in the Amazon rainforest."
    ]

    # Create interface
    with gr.Blocks(title="Multi-Agent Document Classification System", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # Multi-Agent AI Collaboration System for Document Classification
        ## Author: Spencer Purdy

        This system uses multiple specialized machine learning models (agents) that collaborate
        to classify documents into 20 different categories from the newsgroups dataset.

        ### System Architecture:
        - **TF-IDF Agent**: Specializes in statistical text features using Logistic Regression
        - **Embedding Agent**: Captures semantic meaning using neural networks and sentence embeddings
        - **XGBoost Agent**: Handles mixed features with gradient boosting
        - **Ensemble Coordinator**: Combines agent predictions using stacking for optimal performance

        ### Dataset:
        - 20 Newsgroups dataset (publicly available, approx. 18,000 documents)
        - 20 categories covering various topics (technology, sports, politics, etc.)
        """)

        with gr.Tab("Document Classification"):
            gr.Markdown("### Enter text to classify:")

            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Input Text",
                        placeholder="Enter document text here...",
                        lines=10
                    )

                    classify_btn = gr.Button("Classify Document", variant="primary")

                    gr.Examples(
                        examples=examples,
                        inputs=text_input,
                        label="Example Documents"
                    )

                with gr.Column(scale=1):
                    output_text = gr.Markdown(label="Prediction Results")
                    confidence_plot = gr.Plot(label="Confidence Scores")

            classify_btn.click(
                fn=predict_text,
                inputs=[text_input],
                outputs=[output_text, confidence_plot]
            )

        with gr.Tab("System Performance"):
            gr.Markdown("""
            ### Model Performance on Test Set

            The system was evaluated on a held-out test set. Below are the performance metrics
            for individual agents and ensemble methods.
            """)

            gr.Plot(value=perf_fig, label="Performance Comparison")

            gr.Markdown("""
            ### Performance Summary:

            Individual agents show good performance, with each specializing in different aspects:
            - TF-IDF Agent: Fast, interpretable, good with keyword-based classification
            - Embedding Agent: Captures semantic similarity, handles paraphrasing well
            - XGBoost Agent: Robust with mixed features, handles complex patterns

            Ensemble methods combine agent strengths:
            - Weighted Voting: Simple combination based on validation performance
            - Stacking: Meta-learner optimally combines agent predictions

            The stacking ensemble typically achieves the best performance by learning
            how to weight each agent for different types of documents.
            """)

        with gr.Tab("Confusion Matrix"):
            gr.Markdown("""
            ### Confusion Matrix

            Shows where the stacking ensemble makes correct and incorrect predictions.
            Darker colors indicate more predictions in that cell.
            """)

            gr.Plot(value=cm_fig, label="Confusion Matrix")

        with gr.Tab("Model Information"):
            gr.Markdown(f"""
            ### System Information

            **Training Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

            **Configuration:**
            - Random Seed: {config.random_seed}
            - Training Set Size: {len(system.train_df)} documents
            - Validation Set Size: {len(system.val_df)} documents
            - Test Set Size: {len(system.test_df)} documents
            - Number of Categories: {len(system.categories)}

            **Categories:**
            {', '.join(system.categories)}

            **Agent Training Times:**
            """)

            metrics_df = pd.DataFrame([
                {
                    'Agent': 'TF-IDF Agent',
                    'Training Time (s)': f"{training_results['agent_metrics'][0]['training_time']:.2f}",
                    'Validation Accuracy': f"{training_results['agent_metrics'][0]['accuracy']:.4f}",
                    'Validation F1': f"{training_results['agent_metrics'][0]['f1_weighted']:.4f}"
                },
                {
                    'Agent': 'Embedding Agent',
                    'Training Time (s)': f"{training_results['agent_metrics'][1]['training_time']:.2f}",
                    'Validation Accuracy': f"{training_results['agent_metrics'][1]['accuracy']:.4f}",
                    'Validation F1': f"{training_results['agent_metrics'][1]['f1_weighted']:.4f}"
                },
                {
                    'Agent': 'XGBoost Agent',
                    'Training Time (s)': f"{training_results['agent_metrics'][2]['training_time']:.2f}",
                    'Validation Accuracy': f"{training_results['agent_metrics'][2]['accuracy']:.4f}",
                    'Validation F1': f"{training_results['agent_metrics'][2]['f1_weighted']:.4f}"
                }
            ])

            gr.DataFrame(value=metrics_df, label="Agent Training Metrics")

            gr.Markdown(f"""
            ### Model Limitations and Failure Cases

            **Known Limitations:**
            1. **Domain Specificity**: Trained on newsgroup data; may not generalize well to
               significantly different domains (e.g., legal documents, medical reports)
            2. **Short Text**: Performance may degrade on very short documents (< 50 words)
            3. **Ambiguous Content**: Documents covering multiple topics may be misclassified
            4. **Training Data Bias**: Performance reflects biases present in the training data
            5. **Language**: Only trained on English text

            **Expected Failure Cases:**
            - Documents mixing multiple topics from different categories
            - Highly technical jargon not present in training data
            - Sarcasm, irony, or implicit meaning
            - Very long documents (> 10,000 words) may lose context
            - Non-English text or code-switched content

            **Uncertainty Indicators:**
            - Confidence < 50%: Prediction is highly uncertain; consider human review
            - Top 2 predictions very close: Document may belong to multiple categories
            - Agent votes disagree significantly: Complex or ambiguous document

            ### Ethical Considerations

            This system should be used responsibly:
            - Not suitable for high-stakes decisions without human oversight
            - May perpetuate biases present in training data
            - Should be regularly monitored and updated with new data
            - Users should verify important predictions

            ### Technical Details

            **Feature Engineering:**
            - TF-IDF: 5000 features, bigrams, sublinear TF scaling
            - Embeddings: 384-dimensional sentence-transformers (all-MiniLM-L6-v2)
            - Metadata: Document length, word count, average word length

            **Model Architectures:**
            - TF-IDF Agent: Logistic Regression (L2 regularization)
            - Embedding Agent: Feed-forward neural network (384 -> 256 -> 128 -> 20)
            - XGBoost Agent: 200 estimators, max depth 6, learning rate 0.1
            - Meta-learner: Logistic Regression on stacked predictions

            **Reproducibility:**
            All random seeds are set to {config.random_seed} for reproducibility.
            Training on the same data with the same configuration should yield very similar results.
            """)

        with gr.Tab("About"):
            gr.Markdown("""
            ### About This System

            **Project:** Multi-Agent AI Collaboration System for Document Classification

            **Author:** Spencer Purdy

            **Purpose:** Demonstrate genuine multi-model machine learning collaboration
            for document classification and routing.

            **Real-World Applications:**
            - Customer support ticket routing
            - Email categorization
            - Content moderation
            - Document management systems
            - News article classification

            **Dataset:**
            - 20 Newsgroups dataset
            - Publicly available via scikit-learn
            - Approximately 18,000 newsgroup posts
            - 20 categories covering diverse topics
            - No personal or sensitive information

            **Technology Stack:**
            - scikit-learn: Classical ML algorithms and pipelines
            - PyTorch: Neural network implementation
            - sentence-transformers: Semantic embeddings
            - XGBoost: Gradient boosting
            - Gradio: User interface

            **Development:**
            - Developed and tested in Google Colab
            - Can be deployed to Hugging Face Spaces
            - All dependencies explicitly versioned
            - Code is documented and follows best practices

            **License:**
            - Code: MIT License
            - Dataset: Public domain (20 Newsgroups)

            **Contact:**
            For questions or issues, please contact Spencer Purdy.

            **Acknowledgments:**
            - 20 Newsgroups dataset creators
            - scikit-learn team
            - Hugging Face for sentence-transformers
            - Open source ML community
            """)

    return interface

# Main execution
if __name__ == "__main__":
    logger.info("=" * 70)
    logger.info("Multi-Agent AI Collaboration System")
    logger.info("Author: Spencer Purdy")
    logger.info("=" * 70)
    logger.info(f"Random seed: {RANDOM_SEED}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Initialize system
    logger.info("\nInitializing system...")
    system = MultiAgentSystem(config)

    # Train system
    logger.info("\nStarting training process...")
    training_results = system.train_full_system()

    # Create and launch interface
    logger.info("\nCreating Gradio interface...")
    interface = create_gradio_interface(system, training_results)

    logger.info("\nLaunching interface...")
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )