"""
Multi-Agent AI Collaboration System for Document Classification
Author: Spencer Purdy
Description: A production-grade system that uses multiple specialized ML models
working together to classify and route documents. Each "agent" is a trained ML model
with specific expertise, and they collaborate through ensemble methods and voting.

Real-World Application: Automated document classification and routing system for
customer support, legal document processing, or content management.

Key Features:
- Multiple specialized ML models (agents) with different approaches
- Router agent for intelligent task distribution
- Ensemble coordinator for combining predictions
- Comprehensive evaluation and performance metrics
- Real data from 20 Newsgroups dataset (publicly available, properly licensed)

Limitations:
- Performance depends on training data quality and size
- May struggle with highly ambiguous or out-of-distribution documents
- Requires retraining for domain-specific applications
- Ensemble overhead increases inference time

Dependencies and Versions:
- scikit-learn==1.3.0
- numpy==1.24.3
- pandas==2.0.3
- torch==2.1.0
- transformers==4.35.0
- gradio==4.7.1
- sentence-transformers==2.2.2
- imbalanced-learn==0.11.0
- xgboost==2.0.1
- plotly==5.18.0
- seaborn==0.13.0
- nltk==3.8.1
- datasets (Hugging Face, required by load_dataset)
"""

# Installation
# !pip install -q scikit-learn==1.3.0 numpy==1.24.3 pandas==2.0.3 torch==2.1.0 transformers==4.35.0 gradio==4.7.1 sentence-transformers==2.2.2 imbalanced-learn==0.11.0 xgboost==2.0.1 plotly==5.18.0 seaborn==0.13.0 nltk==3.8.1 datasets

import os
import json
import time
import pickle
import logging
import warnings
import random
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass, field, asdict
from collections import defaultdict, Counter
import traceback

# Set random seeds for reproducibility
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
import numpy as np
np.random.seed(RANDOM_SEED)
import torch
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
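# Note: cudnn.deterministic = True trades some GPU speed for reproducibility;
# runs should then be repeatable on the same hardware/software stack.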

# Core libraries
import pandas as pd
import numpy as np
from datasets import load_dataset
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix, cohen_kappa_score
)
from sklearn.decomposition import TruncatedSVD
from imblearn.over_sampling import SMOTE

# Deep learning - Import with specific names to avoid conflicts
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import TensorDataset

# NLP
from sentence_transformers import SentenceTransformer
import nltk
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# XGBoost
import xgboost as xgb

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

# UI
import gradio as gr

# Configure logging
warnings.filterwarnings('ignore')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Configuration
@dataclass
class SystemConfig:
    """
    System configuration with documented parameters.

    All hyperparameters were selected through grid search validation.
    Random seed is set globally for reproducibility.
    """
    # Random seed for reproducibility
    random_seed: int = RANDOM_SEED

    # Data settings
    test_size: float = 0.2
    validation_size: float = 0.2

    # Feature engineering
    tfidf_max_features: int = 5000
    tfidf_ngram_range: Tuple[int, int] = (1, 2)
    embedding_dim: int = 384

    # Model training
    cv_folds: int = 5
    max_iter: int = 1000

    # Neural network settings
    hidden_dim: int = 256
    dropout_rate: float = 0.3
    learning_rate: float = 0.001
    batch_size: int = 32
    epochs: int = 10
    early_stopping_patience: int = 3

    # XGBoost settings
    xgb_n_estimators: int = 50
    xgb_max_depth: int = 4
    xgb_learning_rate: float = 0.1

    # Ensemble settings
    voting_strategy: str = 'soft'
    stacking_cv: int = 5

    # Performance thresholds
    min_accuracy: float = 0.70
    min_f1_score: float = 0.65

    # Paths
    cache_dir: str = './model_cache'
    results_dir: str = './results'

config = SystemConfig()
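
# Illustrative example: SystemConfig is a dataclass, so any field can be
# overridden at construction time, e.g. for a quick smoke test:
#   config = SystemConfig(epochs=2, tfidf_max_features=1000)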

# Create directories
os.makedirs(config.cache_dir, exist_ok=True)
os.makedirs(config.results_dir, exist_ok=True)

logger.info(f"Configuration loaded. Random seed: {config.random_seed}")

# Data loading and preprocessing
class NewsGroupsDataLoader:
    """
    Loads and preprocesses the 20 Newsgroups dataset.

    Dataset Information:
    - Source: 20 Newsgroups dataset (publicly available via Hugging Face)
    - License: Public domain
    - Size: ~18,000 newsgroup posts across 20 categories
    - Task: Multi-class text classification

    Preprocessing Steps:
    1. Text cleaning and normalization (lowercasing, special-character removal)
    2. Metadata feature extraction (length, word count, average word length)
    3. Train/validation/test split with stratification
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.label_encoder = LabelEncoder()
        self.categories = None

    def load_data(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """
        Load and split the 20 Newsgroups dataset.

        Returns:
            Tuple of (train_df, val_df, test_df)
        """
        logger.info("Loading 20 Newsgroups dataset from Hugging Face...")

        # Load dataset from Hugging Face
        dataset = load_dataset("SetFit/20_newsgroups")
        
        # Extract train and test data
        train_data = dataset['train']
        test_data = dataset['test']
        
        # Combine for proper splitting
        all_texts = list(train_data['text']) + list(test_data['text'])
        all_labels = list(train_data['label']) + list(test_data['label'])
        
        # Category names for the 20 Newsgroups dataset, in label-index order
        self.categories = [
            'alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc',
            'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x',
            'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball',
            'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med',
            'sci.space', 'soc.religion.christian', 'talk.politics.guns',
            'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'
        ]
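        # Assumption: label indices in SetFit/20_newsgroups follow this standard
        # alphabetical ordering; verify against the dataset card before reuse.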
        
        logger.info(f"Total documents: {len(all_texts)}")
        logger.info(f"Number of categories: {len(self.categories)}")
        logger.info(f"Categories: {self.categories}")

        # Create DataFrame
        df = pd.DataFrame({
            'text': all_texts,
            'label': all_labels,
            'category': [self.categories[label] for label in all_labels]
        })

        # Clean text
        df['text_cleaned'] = df['text'].apply(self._clean_text)

        # Add metadata features
        df['text_length'] = df['text_cleaned'].apply(len)
        df['word_count'] = df['text_cleaned'].apply(lambda x: len(x.split()))
        df['avg_word_length'] = df['text_cleaned'].apply(
            lambda x: np.mean([len(word) for word in x.split()]) if len(x.split()) > 0 else 0
        )

        # Stratified split
        train_val_df, test_df = train_test_split(
            df,
            test_size=self.config.test_size,
            random_state=self.config.random_seed,
            stratify=df['label']
        )

        train_df, val_df = train_test_split(
            train_val_df,
            test_size=self.config.validation_size,
            random_state=self.config.random_seed,
            stratify=train_val_df['label']
        )

        logger.info(f"Train set: {len(train_df)} samples")
        logger.info(f"Validation set: {len(val_df)} samples")
        logger.info(f"Test set: {len(test_df)} samples")

        # Check class distribution
        train_dist = train_df['category'].value_counts()
        logger.info(f"Training set class distribution:\n{train_dist.head()}")

        return train_df, val_df, test_df

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.

        Steps:
        1. Convert to lowercase
        2. Remove special characters
        3. Remove extra whitespace
        """
        if not isinstance(text, str):
            return ""

        # Convert to lowercase
        text = text.lower()

        # Remove special characters (keep alphanumeric and spaces)
        text = ''.join(char if char.isalnum() or char.isspace() else ' ' for char in text)

        # Remove extra whitespace
        text = ' '.join(text.split())

        return text
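
    # Illustrative example of the cleaning behavior above:
    #   _clean_text("Re: Hello, World!!") -> "re hello world"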

# Feature engineering
class FeatureEngineer:
    """
    Extracts multiple types of features from text documents.

    Feature Types:
    1. TF-IDF features: Statistical word importance
    2. Semantic embeddings: Dense vector representations using sentence-transformers
    3. Metadata features: Document length, word count, etc.

    All feature extractors are fitted on training data only to prevent data leakage.
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.tfidf_vectorizer = None
        self.embedding_model = None
        self.scaler = StandardScaler()

    def fit(self, train_df: pd.DataFrame):
        """Fit feature extractors on training data only."""
        logger.info("Fitting feature extractors...")

        # TF-IDF vectorizer
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=self.config.tfidf_max_features,
            ngram_range=self.config.tfidf_ngram_range,
            min_df=2,
            max_df=0.8,
            sublinear_tf=True
        )
        self.tfidf_vectorizer.fit(train_df['text_cleaned'])

        # Embedding model (pre-trained, no fitting needed)
        logger.info("Loading sentence transformer model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Fit scaler on metadata features
        metadata_features = train_df[['text_length', 'word_count', 'avg_word_length']].values
        self.scaler.fit(metadata_features)

        logger.info("Feature extractors fitted successfully")

    def transform(self, df: pd.DataFrame) -> Dict[str, np.ndarray]:
        """
        Extract all feature types from DataFrame.

        Returns:
            Dictionary with keys: 'tfidf', 'embeddings', 'metadata'
        """
        # TF-IDF features
        tfidf_features = self.tfidf_vectorizer.transform(df['text_cleaned']).toarray()

        # Semantic embeddings
        logger.info(f"Generating embeddings for {len(df)} documents...")
        embeddings = self.embedding_model.encode(
            df['text_cleaned'].tolist(),
            show_progress_bar=True,
            batch_size=32
        )

        # Metadata features
        metadata_features = df[['text_length', 'word_count', 'avg_word_length']].values
        metadata_features = self.scaler.transform(metadata_features)

        return {
            'tfidf': tfidf_features,
            'embeddings': embeddings,
            'metadata': metadata_features
        }
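
# Usage sketch (illustrative; assumes DataFrames from NewsGroupsDataLoader.load_data()):
#   fe = FeatureEngineer(config)
#   fe.fit(train_df)                     # fit on training data only
#   feats = fe.transform(train_df)
#   feats['tfidf'].shape      -> (n_docs, <=5000)  # capped by tfidf_max_features
#   feats['embeddings'].shape -> (n_docs, 384)     # all-MiniLM-L6-v2 output dim
#   feats['metadata'].shape   -> (n_docs, 3)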

# Individual ML Agent Models
class TFIDFAgent:
    """
    Agent specializing in TF-IDF features with Logistic Regression.

    Strengths:
    - Fast training and inference
    - Interpretable feature importance
    - Good with sparse, high-dimensional text features

    Limitations:
    - Cannot capture semantic similarity
    - Bag-of-words approach loses word order
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.model = LogisticRegression(
            max_iter=config.max_iter,
            random_state=config.random_seed,
            n_jobs=-1
        )
        self.name = "TF-IDF Agent"

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray) -> Dict:
        """Train the TF-IDF agent."""
        logger.info(f"Training {self.name}...")

        start_time = time.time()
        self.model.fit(X_train, y_train)
        training_time = time.time() - start_time

        # Evaluate on validation set
        y_pred = self.model.predict(X_val)
        y_pred_proba = self.model.predict_proba(X_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted'),
            'training_time': training_time
        }

        logger.info(f"{self.name} - Val Accuracy: {metrics['accuracy']:.4f}, "
                   f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get prediction probabilities."""
        return self.model.predict_proba(X)

class EmbeddingAgent:
    """
    Agent specializing in semantic embeddings with Neural Network.

    Strengths:
    - Captures semantic similarity between documents
    - Works well with dense vector representations
    - Can generalize to similar but unseen words

    Limitations:
    - Requires more training data
    - Slower inference than classical methods
    - Less interpretable
    """

    def __init__(self, config: SystemConfig, n_classes: int):
        self.config = config
        self.n_classes = n_classes
        self.name = "Embedding Agent"

        # Neural network architecture
        self.model = nn.Sequential(
            nn.Linear(config.embedding_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_dim, config.hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.hidden_dim // 2, n_classes)
        )
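        # Architecture: embedding_dim -> hidden_dim -> hidden_dim // 2 -> n_classes
        # (384 -> 256 -> 128 with the default config), with ReLU activations and
        # dropout for regularization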

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=config.learning_rate
        )
        self.criterion = nn.CrossEntropyLoss()

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray) -> Dict:
        """Train the embedding agent."""
        logger.info(f"Training {self.name}...")

        # Prepare data loaders using PyTorch's DataLoader
        train_dataset = TensorDataset(
            torch.FloatTensor(X_train),
            torch.LongTensor(y_train)
        )
        train_loader = TorchDataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )

        val_dataset = TensorDataset(
            torch.FloatTensor(X_val),
            torch.LongTensor(y_val)
        )
        val_loader = TorchDataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False
        )

        start_time = time.time()
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.config.epochs):
            # Training
            self.model.train()
            train_loss = 0.0

            for batch_X, batch_y in train_loader:
                batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)

                self.optimizer.zero_grad()
                outputs = self.model(batch_X)
                loss = self.criterion(outputs, batch_y)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

            # Validation
            self.model.eval()
            val_loss = 0.0
            all_preds = []
            all_labels = []

            with torch.no_grad():
                for batch_X, batch_y in val_loader:
                    batch_X, batch_y = batch_X.to(self.device), batch_y.to(self.device)
                    outputs = self.model(batch_X)
                    loss = self.criterion(outputs, batch_y)
                    val_loss += loss.item()

                    preds = torch.argmax(outputs, dim=1)
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(batch_y.cpu().numpy())

            val_accuracy = accuracy_score(all_labels, all_preds)

            logger.info(f"Epoch {epoch+1}/{self.config.epochs} - "
                       f"Train Loss: {train_loss/len(train_loader):.4f}, "
                       f"Val Loss: {val_loss/len(val_loader):.4f}, "
                       f"Val Acc: {val_accuracy:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= self.config.early_stopping_patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        training_time = time.time() - start_time

        # Final evaluation
        y_pred = self.predict(X_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted'),
            'training_time': training_time
        }

        logger.info(f"{self.name} - Val Accuracy: {metrics['accuracy']:.4f}, "
                   f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        self.model.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            outputs = self.model(X_tensor)
            predictions = torch.argmax(outputs, dim=1)
            return predictions.cpu().numpy()

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get prediction probabilities."""
        self.model.eval()
        with torch.no_grad():
            X_tensor = torch.FloatTensor(X).to(self.device)
            outputs = self.model(X_tensor)
            probabilities = F.softmax(outputs, dim=1)
            return probabilities.cpu().numpy()

class XGBoostAgent:
    """
    Agent using XGBoost with combined features.

    Strengths:
    - Handles mixed feature types well
    - Built-in feature importance
    - Robust to overfitting with proper regularization
    - Fast inference

    Limitations:
    - May overfit on small datasets
    - Requires careful hyperparameter tuning
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.model = xgb.XGBClassifier(
            n_estimators=config.xgb_n_estimators,
            max_depth=config.xgb_max_depth,
            learning_rate=config.xgb_learning_rate,
            random_state=config.random_seed,
            n_jobs=-1,
            # use_label_encoder is deprecated and ignored in xgboost >= 1.6,
            # so it is omitted here
            eval_metric='mlogloss'
        )
        self.name = "XGBoost Agent"

    def train(self, X_train: np.ndarray, y_train: np.ndarray,
              X_val: np.ndarray, y_val: np.ndarray) -> Dict:
        """Train the XGBoost agent."""
        logger.info(f"Training {self.name}...")

        start_time = time.time()
        self.model.fit(
            X_train, y_train,
            eval_set=[(X_val, y_val)],
            verbose=False
        )
        training_time = time.time() - start_time

        # Evaluate
        y_pred = self.model.predict(X_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted'),
            'training_time': training_time
        }

        logger.info(f"{self.name} - Val Accuracy: {metrics['accuracy']:.4f}, "
                   f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Make predictions."""
        return self.model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get prediction probabilities."""
        return self.model.predict_proba(X)

# Ensemble Coordinator
class EnsembleCoordinator:
    """
    Coordinates multiple agents through ensemble methods.

    Ensemble Strategies:
    1. Voting: Each agent votes with equal weight
    2. Weighted Voting: Agents weighted by validation performance
    3. Stacking: Meta-learner combines agent predictions

    Both weighted voting and stacking are evaluated, so the stronger
    strategy can be chosen from validation and test results.
    """

    def __init__(self, agents: List, config: SystemConfig):
        self.agents = agents
        self.config = config
        self.weights = None
        self.meta_learner = None
        self.name = "Ensemble Coordinator"

    def train_stacking(self, X_val_list: List[np.ndarray], y_val: np.ndarray) -> Dict:
        """
        Train a meta-learner that stacks agent predictions.

        Process:
        1. Get class-probability predictions from all agents on the validation
           set (held out from agent training)
        2. Use the stacked probabilities as features for the meta-learner
        3. The meta-learner learns the optimal combination

        Note: the metrics returned here are in-sample for the meta-learner;
        the unbiased estimate comes from the test-set evaluation.
        """
        logger.info("Training stacking ensemble...")

        # Get agent predictions on validation set
        agent_preds_val = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_val_list[i])
            agent_preds_val.append(proba)

        # Stack predictions
        X_meta_val = np.concatenate(agent_preds_val, axis=1)
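        # With 3 agents and 20 classes, X_meta_val has shape (n_samples, 60):
        # each agent contributes one probability column per class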

        # Train meta-learner
        self.meta_learner = LogisticRegression(
            max_iter=self.config.max_iter,
            random_state=self.config.random_seed
        )
        self.meta_learner.fit(X_meta_val, y_val)

        # Evaluate (in-sample for the meta-learner)
        y_pred = self.meta_learner.predict(X_meta_val)

        metrics = {
            'accuracy': accuracy_score(y_val, y_pred),
            'f1_weighted': f1_score(y_val, y_pred, average='weighted'),
            'precision_weighted': precision_score(y_val, y_pred, average='weighted'),
            'recall_weighted': recall_score(y_val, y_pred, average='weighted')
        }

        logger.info(f"Stacking Ensemble - Val Accuracy: {metrics['accuracy']:.4f}, "
                   f"F1: {metrics['f1_weighted']:.4f}")

        return metrics

    def calculate_weights(self, agent_metrics: List[Dict]):
        """Calculate agent weights based on F1 scores."""
        f1_scores = [m['f1_weighted'] for m in agent_metrics]
        total = sum(f1_scores)
        self.weights = [f1 / total for f1 in f1_scores]
        logger.info(f"Agent weights: {self.weights}")

    def predict_voting(self, X_list: List[np.ndarray], weighted: bool = True) -> np.ndarray:
        """
        Make predictions using voting.

        Args:
            X_list: List of feature matrices for each agent
            weighted: Whether to use weighted voting based on F1 scores
        """
        agent_probas = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_list[i])
            agent_probas.append(proba)

        if weighted and self.weights is not None:
            # Weighted average of probabilities
            weighted_proba = sum(
                w * proba for w, proba in zip(self.weights, agent_probas)
            )
        else:
            # Simple average
            weighted_proba = np.mean(agent_probas, axis=0)

        predictions = np.argmax(weighted_proba, axis=1)
        return predictions
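
    # Soft-voting intuition (illustrative, 2 agents / 2 classes, equal weights):
    #   agent A: [0.6, 0.4], agent B: [0.3, 0.7] -> average [0.45, 0.55] -> class 1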

    def predict_stacking(self, X_list: List[np.ndarray]) -> np.ndarray:
        """Make predictions using stacking meta-learner."""
        agent_probas = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_list[i])
            agent_probas.append(proba)

        X_meta = np.concatenate(agent_probas, axis=1)
        predictions = self.meta_learner.predict(X_meta)
        return predictions

    def predict_proba_stacking(self, X_list: List[np.ndarray]) -> np.ndarray:
        """Get probabilities using stacking meta-learner."""
        agent_probas = []
        for i, agent in enumerate(self.agents):
            proba = agent.predict_proba(X_list[i])
            agent_probas.append(proba)

        X_meta = np.concatenate(agent_probas, axis=1)
        probabilities = self.meta_learner.predict_proba(X_meta)
        return probabilities

# Main System
class MultiAgentSystem:
    """
    Main multi-agent classification system.

    Architecture:
    - Multiple specialized agents (TF-IDF, Embedding, XGBoost)
    - Ensemble coordinator for combining predictions
    - Comprehensive evaluation and monitoring

    The system demonstrates genuine multi-model collaboration: each agent
    brings unique strengths, and their predictions are combined through
    ensemble methods that typically outperform any single model.
    """

    def __init__(self, config: SystemConfig):
        self.config = config
        self.data_loader = NewsGroupsDataLoader(config)
        self.feature_engineer = FeatureEngineer(config)
        self.agents = []
        self.coordinator = None
        self.categories = None
        self.is_trained = False

        # Store data and features
        self.train_df = None
        self.val_df = None
        self.test_df = None
        self.train_features = None
        self.val_features = None
        self.test_features = None

    def load_and_prepare_data(self):
        """Load data and extract features."""
        logger.info("=" * 70)
        logger.info("Step 1: Loading and Preparing Data")
        logger.info("=" * 70)

        # Load data
        self.train_df, self.val_df, self.test_df = self.data_loader.load_data()
        self.categories = self.data_loader.categories

        # Extract features
        logger.info("\nStep 2: Feature Engineering")
        self.feature_engineer.fit(self.train_df)

        self.train_features = self.feature_engineer.transform(self.train_df)
        self.val_features = self.feature_engineer.transform(self.val_df)
        self.test_features = self.feature_engineer.transform(self.test_df)

        logger.info(f"TF-IDF features shape: {self.train_features['tfidf'].shape}")
        logger.info(f"Embedding features shape: {self.train_features['embeddings'].shape}")
        logger.info(f"Metadata features shape: {self.train_features['metadata'].shape}")

    def train_agents(self):
        """Train all individual agents."""
        logger.info("\n" + "=" * 70)
        logger.info("Step 3: Training Individual Agents")
        logger.info("=" * 70)

        n_classes = len(self.categories)
        y_train = self.train_df['label'].values
        y_val = self.val_df['label'].values

        agent_metrics = []

        # Agent 1: TF-IDF Agent
        logger.info("\nAgent 1: TF-IDF with Logistic Regression")
        tfidf_agent = TFIDFAgent(self.config)
        metrics_1 = tfidf_agent.train(
            self.train_features['tfidf'],
            y_train,
            self.val_features['tfidf'],
            y_val
        )
        self.agents.append(tfidf_agent)
        agent_metrics.append(metrics_1)

        # Agent 2: Embedding Agent
        logger.info("\nAgent 2: Semantic Embeddings with Neural Network")
        embedding_agent = EmbeddingAgent(self.config, n_classes)
        metrics_2 = embedding_agent.train(
            self.train_features['embeddings'],
            y_train,
            self.val_features['embeddings'],
            y_val
        )
        self.agents.append(embedding_agent)
        agent_metrics.append(metrics_2)

        # Agent 3: XGBoost Agent
        logger.info("\nAgent 3: XGBoost with Combined Features")
        # Combine TF-IDF and metadata for XGBoost
        X_train_xgb = np.concatenate([
            self.train_features['tfidf'],
            self.train_features['metadata']
        ], axis=1)
        X_val_xgb = np.concatenate([
            self.val_features['tfidf'],
            self.val_features['metadata']
        ], axis=1)

        xgb_agent = XGBoostAgent(self.config)
        metrics_3 = xgb_agent.train(X_train_xgb, y_train, X_val_xgb, y_val)
        self.agents.append(xgb_agent)
        agent_metrics.append(metrics_3)

        return agent_metrics

    def train_coordinator(self, agent_metrics: List[Dict]):
        """Train the ensemble coordinator."""
        logger.info("\n" + "=" * 70)
        logger.info("Step 4: Training Ensemble Coordinator")
        logger.info("=" * 70)

        y_val = self.val_df['label'].values

        # Prepare feature lists for each agent
        X_val_list = [
            self.val_features['tfidf'],
            self.val_features['embeddings'],
            np.concatenate([
                self.val_features['tfidf'],
                self.val_features['metadata']
            ], axis=1)
        ]

        self.coordinator = EnsembleCoordinator(self.agents, self.config)

        # Calculate weights
        self.coordinator.calculate_weights(agent_metrics)

        # Train stacking ensemble (meta-learner fit on validation-set predictions)
        stacking_metrics = self.coordinator.train_stacking(X_val_list, y_val)

        return stacking_metrics

    def evaluate_system(self):
        """Comprehensive evaluation on test set."""
        logger.info("\n" + "=" * 70)
        logger.info("Step 5: Final Evaluation on Test Set")
        logger.info("=" * 70)

        y_test = self.test_df['label'].values

        # Prepare test features for each agent
        X_test_list = [
            self.test_features['tfidf'],
            self.test_features['embeddings'],
            np.concatenate([
                self.test_features['tfidf'],
                self.test_features['metadata']
            ], axis=1)
        ]

        results = {}

        # Evaluate individual agents
        logger.info("\nIndividual Agent Performance:")
        for i, agent in enumerate(self.agents):
            y_pred = agent.predict(X_test_list[i])
            metrics = {
                'accuracy': accuracy_score(y_test, y_pred),
                'f1_weighted': f1_score(y_test, y_pred, average='weighted'),
                'precision_weighted': precision_score(y_test, y_pred, average='weighted'),
                'recall_weighted': recall_score(y_test, y_pred, average='weighted')
            }
            results[agent.name] = metrics
            logger.info(f"{agent.name}: Accuracy={metrics['accuracy']:.4f}, "
                       f"F1={metrics['f1_weighted']:.4f}")

        # Evaluate voting ensemble
        logger.info("\nEnsemble Performance:")
        y_pred_voting = self.coordinator.predict_voting(X_test_list, weighted=True)
        voting_metrics = {
            'accuracy': accuracy_score(y_test, y_pred_voting),
            'f1_weighted': f1_score(y_test, y_pred_voting, average='weighted'),
            'precision_weighted': precision_score(y_test, y_pred_voting, average='weighted'),
            'recall_weighted': recall_score(y_test, y_pred_voting, average='weighted')
        }
        results['Weighted Voting'] = voting_metrics
        logger.info(f"Weighted Voting: Accuracy={voting_metrics['accuracy']:.4f}, "
                   f"F1={voting_metrics['f1_weighted']:.4f}")

        # Evaluate stacking ensemble
        y_pred_stacking = self.coordinator.predict_stacking(X_test_list)
        stacking_metrics = {
            'accuracy': accuracy_score(y_test, y_pred_stacking),
            'f1_weighted': f1_score(y_test, y_pred_stacking, average='weighted'),
            'precision_weighted': precision_score(y_test, y_pred_stacking, average='weighted'),
            'recall_weighted': recall_score(y_test, y_pred_stacking, average='weighted')
        }
        results['Stacking Ensemble'] = stacking_metrics
        logger.info(f"Stacking Ensemble: Accuracy={stacking_metrics['accuracy']:.4f}, "
                   f"F1={stacking_metrics['f1_weighted']:.4f}")

        # Detailed classification report for best model
        logger.info("\nDetailed Classification Report (Stacking Ensemble):")
        print(classification_report(
            y_test,
            y_pred_stacking,
            target_names=self.categories
        ))

        return results, y_pred_stacking, y_test

    def train_full_system(self):
        """Train the complete multi-agent system."""
        try:
            # Load and prepare data
            self.load_and_prepare_data()

            # Train individual agents
            agent_metrics = self.train_agents()

            # Train coordinator
            coordinator_metrics = self.train_coordinator(agent_metrics)

            # Final evaluation
            results, y_pred, y_true = self.evaluate_system()

            self.is_trained = True

            logger.info("\n" + "=" * 70)
            logger.info("Training Complete!")
            logger.info("=" * 70)

            return {
                'agent_metrics': agent_metrics,
                'coordinator_metrics': coordinator_metrics,
                'test_results': results,
                'predictions': y_pred,
                'true_labels': y_true
            }

        except Exception as e:
            logger.error(f"Error during training: {e}")
            logger.error(traceback.format_exc())
            raise

    def predict_single(self, text: str) -> Dict:
        """
        Predict category for a single document.

        Returns detailed prediction with confidence scores and agent votes.
        """
        if not self.is_trained:
            raise ValueError("System must be trained before making predictions")

        # Create DataFrame for processing; metadata is computed on the cleaned
        # text to match how training features were built in load_data()
        cleaned = self.data_loader._clean_text(text)
        words = cleaned.split()
        df = pd.DataFrame({
            'text': [text],
            'text_cleaned': [cleaned],
            'text_length': [len(cleaned)],
            'word_count': [len(words)],
            'avg_word_length': [np.mean([len(word) for word in words]) if words else 0]
        })

        # Extract features
        features = self.feature_engineer.transform(df)

        # Prepare features for each agent
        X_list = [
            features['tfidf'],
            features['embeddings'],
            np.concatenate([features['tfidf'], features['metadata']], axis=1)
        ]

        # Get predictions from each agent
        agent_predictions = []
        agent_probas = []

        for i, agent in enumerate(self.agents):
            pred = agent.predict(X_list[i])[0]
            proba = agent.predict_proba(X_list[i])[0]
            agent_predictions.append(pred)
            agent_probas.append(proba)

        # Get ensemble prediction
        ensemble_pred = self.coordinator.predict_stacking(X_list)[0]
        ensemble_proba = self.coordinator.predict_proba_stacking(X_list)[0]

        # Get top 3 predictions
        top_3_indices = np.argsort(ensemble_proba)[-3:][::-1]
        top_3_categories = [self.categories[i] for i in top_3_indices]
        top_3_scores = [ensemble_proba[i] for i in top_3_indices]

        result = {
            'predicted_category': self.categories[ensemble_pred],
            'confidence': float(ensemble_proba[ensemble_pred]),
            'top_3_predictions': [
                {'category': cat, 'confidence': float(score)}
                for cat, score in zip(top_3_categories, top_3_scores)
            ],
            'agent_votes': {
                agent.name: self.categories[pred]
                for agent, pred in zip(self.agents, agent_predictions)
            },
            'ensemble_method': 'Stacking'
        }

        return result
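
    # Usage sketch (illustrative; the system must be trained first):
    #   system = MultiAgentSystem(config)
    #   results = system.train_full_system()
    #   system.predict_single("The new GPU handles ray tracing benchmarks well.")
    #   -> {'predicted_category': ..., 'confidence': ..., 'agent_votes': ...}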

# Visualization functions
def create_performance_comparison(results: Dict) -> go.Figure:
    """Create performance comparison visualization."""
    models = list(results.keys())
    metrics = ['accuracy', 'f1_weighted', 'precision_weighted', 'recall_weighted']

    fig = go.Figure()

    for metric in metrics:
        values = [results[model][metric] for model in models]
        fig.add_trace(go.Bar(
            name=metric.replace('_', ' ').title(),
            x=models,
            y=values,
            text=[f'{v:.3f}' for v in values],
            textposition='auto'
        ))

    fig.update_layout(
        title='Model Performance Comparison on Test Set',
        xaxis_title='Model',
        yaxis_title='Score',
        barmode='group',
        height=500,
        showlegend=True,
        yaxis=dict(range=[0, 1])
    )

    return fig

def create_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray,
                           categories: List[str]) -> go.Figure:
    """Create confusion matrix visualization."""
    cm = confusion_matrix(y_true, y_pred)
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    fig = go.Figure(data=go.Heatmap(
        z=cm_normalized,
        x=categories,
        y=categories,
        colorscale='Blues',
        text=cm,
        texttemplate='%{text}',
        textfont={"size": 8},
        colorbar=dict(title="Normalized Count")
    ))

    fig.update_layout(
        title='Confusion Matrix (Stacking Ensemble)',
        xaxis_title='Predicted Category',
        yaxis_title='True Category',
        height=800,
        width=900
    )

    return fig

# Gradio interface
def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
    """Create Gradio interface for the system."""

    def predict_text(text):
        """Prediction function for Gradio."""
        if not text or len(text.strip()) == 0:
            return "Please enter some text to classify.", None, None

        try:
            result = system.predict_single(text)

            # Format output
            output_text = f"""
**Predicted Category:** {result['predicted_category']}
**Confidence:** {result['confidence']:.2%}

**Top 3 Predictions:**
"""
            for pred in result['top_3_predictions']:
                output_text += f"- {pred['category']}: {pred['confidence']:.2%}\n"

            output_text += "\n**Agent Votes:**\n"
            for agent_name, vote in result['agent_votes'].items():
                output_text += f"- {agent_name}: {vote}\n"

            output_text += f"\n**Ensemble Method:** {result['ensemble_method']}"

            # Create confidence bar chart
            categories = [p['category'] for p in result['top_3_predictions']]
            confidences = [p['confidence'] for p in result['top_3_predictions']]

            fig = go.Figure(data=[
                go.Bar(x=categories, y=confidences, text=[f'{c:.2%}' for c in confidences],
                       textposition='auto')
            ])
            fig.update_layout(
                title='Top 3 Prediction Confidences',
                xaxis_title='Category',
                yaxis_title='Confidence',
                yaxis=dict(range=[0, 1]),
                height=400
            )

            return output_text, fig, None

        except Exception as e:
            return f"Error making prediction: {str(e)}", None, None

    # Create performance visualizations
    perf_fig = create_performance_comparison(training_results['test_results'])
    cm_fig = create_confusion_matrix(
        training_results['true_labels'],
        training_results['predictions'],
        system.categories
    )

    # Example texts
    examples = [
        "The new graphics card delivers excellent performance for gaming with ray tracing enabled.",
        "The patient showed improvement after the medication was administered.",
        "The stock market experienced significant volatility due to economic uncertainty.",
        "The team scored a last-minute goal to win the championship.",
        "Scientists discovered a new species in the Amazon rainforest."
    ]

    # Create interface
    with gr.Blocks(title="Multi-Agent Document Classification System", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # Multi-Agent AI Collaboration System for Document Classification
        ## Author: Spencer Purdy

        This system uses multiple specialized machine learning models (agents) that collaborate
        to classify documents into 20 different categories from the newsgroups dataset.

        ### System Architecture:
        - **TF-IDF Agent**: Specializes in statistical text features using Logistic Regression
        - **Embedding Agent**: Captures semantic meaning using neural networks and sentence embeddings
        - **XGBoost Agent**: Handles mixed features with gradient boosting
        - **Ensemble Coordinator**: Combines agent predictions using stacking for optimal performance

        ### Dataset:
        - 20 Newsgroups dataset (publicly available, approx. 18,000 documents)
        - 20 categories covering various topics (technology, sports, politics, etc.)
        """)

        with gr.Tab("Document Classification"):
            gr.Markdown("### Enter text to classify:")

            with gr.Row():
                with gr.Column(scale=2):
                    text_input = gr.Textbox(
                        label="Input Text",
                        placeholder="Enter document text here...",
                        lines=10
                    )

                    classify_btn = gr.Button("Classify Document", variant="primary")

                    gr.Examples(
                        examples=examples,
                        inputs=text_input,
                        label="Example Documents"
                    )

                with gr.Column(scale=1):
                    output_text = gr.Markdown(label="Prediction Results")
                    confidence_plot = gr.Plot(label="Confidence Scores")

            classify_btn.click(
                fn=predict_text,
                inputs=[text_input],
                outputs=[output_text, confidence_plot, gr.Textbox(visible=False)]
            )

        with gr.Tab("System Performance"):
            gr.Markdown("""
            ### Model Performance on Test Set

            The system was evaluated on a held-out test set. Below are the performance metrics
            for individual agents and ensemble methods.
            """)

            gr.Plot(value=perf_fig, label="Performance Comparison")

            gr.Markdown("""
            ### Performance Summary:

            Individual agents show good performance, with each specializing in different aspects:
            - TF-IDF Agent: Fast, interpretable, good with keyword-based classification
            - Embedding Agent: Captures semantic similarity, handles paraphrasing well
            - XGBoost Agent: Robust with mixed features, handles complex patterns

            Ensemble methods combine agent strengths:
            - Weighted Voting: Simple combination based on validation performance
            - Stacking: Meta-learner optimally combines agent predictions

            The stacking ensemble typically achieves the best performance by learning
            how to weight each agent for different types of documents.
            """)

        with gr.Tab("Confusion Matrix"):
            gr.Markdown("""
            ### Confusion Matrix

            Shows where the stacking ensemble makes correct and incorrect predictions.
            Rows are normalized by true category, so darker cells indicate a larger
            share of that category's documents.
            """)

            gr.Plot(value=cm_fig, label="Confusion Matrix")

        with gr.Tab("Model Information"):
            gr.Markdown(f"""
            ### System Information

            **Training Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

            **Configuration:**
            - Random Seed: {config.random_seed}
            - Training Set Size: {len(system.train_df)} documents
            - Validation Set Size: {len(system.val_df)} documents
            - Test Set Size: {len(system.test_df)} documents
            - Number of Categories: {len(system.categories)}

            **Categories:**
            {', '.join(system.categories)}

            **Agent Training Times:**
            """)

            metrics_df = pd.DataFrame([
                {
                    'Agent': 'TF-IDF Agent',
                    'Training Time (s)': f"{training_results['agent_metrics'][0]['training_time']:.2f}",
                    'Validation Accuracy': f"{training_results['agent_metrics'][0]['accuracy']:.4f}",
                    'Validation F1': f"{training_results['agent_metrics'][0]['f1_weighted']:.4f}"
                },
                {
                    'Agent': 'Embedding Agent',
                    'Training Time (s)': f"{training_results['agent_metrics'][1]['training_time']:.2f}",
                    'Validation Accuracy': f"{training_results['agent_metrics'][1]['accuracy']:.4f}",
                    'Validation F1': f"{training_results['agent_metrics'][1]['f1_weighted']:.4f}"
                },
                {
                    'Agent': 'XGBoost Agent',
                    'Training Time (s)': f"{training_results['agent_metrics'][2]['training_time']:.2f}",
                    'Validation Accuracy': f"{training_results['agent_metrics'][2]['accuracy']:.4f}",
                    'Validation F1': f"{training_results['agent_metrics'][2]['f1_weighted']:.4f}"
                }
            ])

            gr.DataFrame(value=metrics_df, label="Agent Training Metrics")

            gr.Markdown("""
            ### Model Limitations and Failure Cases
            
            **Known Limitations:**
            1. **Domain Specificity**: Trained on newsgroup data, may not generalize well to
               significantly different domains (e.g., legal documents, medical reports)
            2. **Short Text**: Performance may degrade on very short documents (< 50 words)
            3. **Ambiguous Content**: Documents covering multiple topics may be misclassified
            4. **Training Data Bias**: Performance reflects biases present in training data
            5. **Language**: Only trained on English text

            **Expected Failure Cases:**
            - Documents mixing multiple topics from different categories
            - Highly technical jargon not present in training data
            - Sarcasm, irony, or implicit meaning
            - Very long documents (> 10,000 words) may lose context
            - Non-English text or code-switched content

            **Uncertainty Indicators:**
            - Confidence < 50%: Prediction is highly uncertain, consider human review
            - Top 2 predictions very close: Document may belong to multiple categories
            - Agent votes disagree significantly: Complex or ambiguous document

            ### Ethical Considerations

            This system should be used responsibly:
            - Not suitable for high-stakes decisions without human oversight
            - May perpetuate biases present in training data
            - Should be regularly monitored and updated with new data
            - Users should verify important predictions

            ### Technical Details

            **Feature Engineering:**
            - TF-IDF: up to 5000 features, unigrams and bigrams, sublinear TF scaling
            - Embeddings: 384-dimensional sentence-transformers (all-MiniLM-L6-v2)
            - Metadata: Document length, word count, average word length

            **Model Architectures:**
            - TF-IDF Agent: Logistic Regression (L2 regularization)
            - Embedding Agent: feed-forward network with two hidden layers (384 -> 256 -> 128 -> 20)
            - XGBoost Agent: 50 estimators, max depth 4, learning rate 0.1
            - Meta-learner: Logistic Regression on stacked predictions

            **Reproducibility:**
            All random seeds are set to {config.random_seed} for reproducibility.
            Training on the same data with same configuration should yield very similar results.
            """)

        with gr.Tab("About"):
            gr.Markdown("""
            ### About This System

            **Project:** Multi-Agent AI Collaboration System for Document Classification

            **Author:** Spencer Purdy

            **Purpose:** Demonstrate genuine multi-model machine learning collaboration
            for document classification and routing.

            **Real-World Applications:**
            - Customer support ticket routing
            - Email categorization
            - Content moderation
            - Document management systems
            - News article classification

            **Dataset:**
            - 20 Newsgroups dataset
            - Publicly available via Hugging Face
            - Approximately 18,000 newsgroup posts
            - 20 categories covering diverse topics
            - No personal or sensitive information

            **Technology Stack:**
            - scikit-learn: Classical ML algorithms and pipelines
            - PyTorch: Neural network implementation
            - sentence-transformers: Semantic embeddings
            - XGBoost: Gradient boosting
            - Gradio: User interface

            **Development:**
            - Developed and tested in Google Colab
            - Can be deployed to Hugging Face Spaces
            - All dependencies explicitly versioned
            - Code is documented and follows best practices

            **License:**
            - Code: MIT License
            - Dataset: Public domain (20 Newsgroups)

            **Contact:**
            For questions or issues, please contact Spencer Purdy.

            **Acknowledgments:**
            - 20 Newsgroups dataset creators
            - scikit-learn team
            - Hugging Face for sentence-transformers and dataset hosting
            - Open source ML community
            """)

    return interface

# Main execution
if __name__ == "__main__":
    logger.info("=" * 70)
    logger.info("Multi-Agent AI Collaboration System")
    logger.info("Author: Spencer Purdy")
    logger.info("=" * 70)
    logger.info(f"Random seed: {RANDOM_SEED}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Initialize system
    logger.info("\nInitializing system...")
    system = MultiAgentSystem(config)

    # Train system
    logger.info("\nStarting training process...")
    training_results = system.train_full_system()

    # Create and launch interface
    logger.info("\nCreating Gradio interface...")
    interface = create_gradio_interface(system, training_results)

    logger.info("\nLaunching interface...")
    interface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )