Update app.py
Browse files
app.py
CHANGED
|
@@ -66,7 +66,7 @@ if torch.cuda.is_available():
|
|
| 66 |
# Core libraries
|
| 67 |
import pandas as pd
|
| 68 |
import numpy as np
|
| 69 |
-
from sklearn.datasets import fetch_20newsgroups
|
| 70 |
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
|
| 71 |
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
|
| 72 |
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
|
@@ -185,7 +185,7 @@ class NewsGroupsDataLoader:
|
|
| 185 |
Loads and preprocesses the 20 Newsgroups dataset.
|
| 186 |
|
| 187 |
Dataset Information:
|
| 188 |
-
- Source: 20 Newsgroups dataset (publicly available via scikit-learn)
|
| 189 |
- License: Public domain
|
| 190 |
- Size: ~18,000 newsgroup posts across 20 categories
|
| 191 |
- Task: Multi-class text classification
|
|
@@ -208,26 +208,21 @@ class NewsGroupsDataLoader:
|
|
| 208 |
Returns:
|
| 209 |
Tuple of (train_df, val_df, test_df)
|
| 210 |
"""
|
| 211 |
-
logger.info("Loading 20 Newsgroups dataset...")
|
| 212 |
-
|
| 213 |
-
# Load
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
# Load test data
|
| 221 |
-
test_data = fetch_20newsgroups(
|
| 222 |
-
subset='test',
|
| 223 |
-
remove=('headers', 'footers', 'quotes'),
|
| 224 |
-
random_state=self.config.random_seed
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
# Combine for proper splitting
|
| 228 |
-
all_texts =
|
| 229 |
-
all_labels =
|
| 230 |
-
|
|
|
|
|
|
|
| 231 |
|
| 232 |
logger.info(f"Total documents: {len(all_texts)}")
|
| 233 |
logger.info(f"Number of categories: {len(self.categories)}")
|
|
@@ -1326,7 +1321,7 @@ def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
|
|
| 1326 |
|
| 1327 |
gr.Markdown("""
|
| 1328 |
### Model Limitations and Failure Cases
|
| 1329 |
-
|
| 1330 |
**Known Limitations:**
|
| 1331 |
1. **Domain Specificity**: Trained on newsgroup data, may not generalize well to
|
| 1332 |
significantly different domains (e.g., legal documents, medical reports)
|
|
@@ -1393,7 +1388,7 @@ def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
|
|
| 1393 |
|
| 1394 |
**Dataset:**
|
| 1395 |
- 20 Newsgroups dataset
|
| 1396 |
-
- Publicly available via scikit-learn
|
| 1397 |
- Approximately 18,000 newsgroup posts
|
| 1398 |
- 20 categories covering diverse topics
|
| 1399 |
- No personal or sensitive information
|
|
@@ -1421,7 +1416,7 @@ def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
|
|
| 1421 |
**Acknowledgments:**
|
| 1422 |
- 20 Newsgroups dataset creators
|
| 1423 |
- scikit-learn team
|
| 1424 |
-
- Hugging Face for sentence-transformers
|
| 1425 |
- Open source ML community
|
| 1426 |
""")
|
| 1427 |
|
|
|
|
| 66 |
# Core libraries
|
| 67 |
import pandas as pd
|
| 68 |
import numpy as np
|
| 69 |
+
from datasets import load_dataset
|
| 70 |
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
|
| 71 |
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
|
| 72 |
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
|
|
|
| 185 |
Loads and preprocesses the 20 Newsgroups dataset.
|
| 186 |
|
| 187 |
Dataset Information:
|
| 188 |
+
- Source: 20 Newsgroups dataset (publicly available via Hugging Face)
|
| 189 |
- License: Public domain
|
| 190 |
- Size: ~18,000 newsgroup posts across 20 categories
|
| 191 |
- Task: Multi-class text classification
|
|
|
|
| 208 |
Returns:
|
| 209 |
Tuple of (train_df, val_df, test_df)
|
| 210 |
"""
|
| 211 |
+
logger.info("Loading 20 Newsgroups dataset from Hugging Face...")
|
| 212 |
+
|
| 213 |
+
# Load dataset from Hugging Face
|
| 214 |
+
dataset = load_dataset("SetFit/20_newsgroups")
|
| 215 |
+
|
| 216 |
+
# Extract train and test data
|
| 217 |
+
train_data = dataset['train']
|
| 218 |
+
test_data = dataset['test']
|
| 219 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
# Combine for proper splitting
|
| 221 |
+
all_texts = train_data['text'] + test_data['text']
|
| 222 |
+
all_labels = train_data['label'] + test_data['label']
|
| 223 |
+
|
| 224 |
+
# Get category names from dataset features
|
| 225 |
+
self.categories = train_data.features['label'].names
|
| 226 |
|
| 227 |
logger.info(f"Total documents: {len(all_texts)}")
|
| 228 |
logger.info(f"Number of categories: {len(self.categories)}")
|
|
|
|
| 1321 |
|
| 1322 |
gr.Markdown("""
|
| 1323 |
### Model Limitations and Failure Cases
|
| 1324 |
+
|
| 1325 |
**Known Limitations:**
|
| 1326 |
1. **Domain Specificity**: Trained on newsgroup data, may not generalize well to
|
| 1327 |
significantly different domains (e.g., legal documents, medical reports)
|
|
|
|
| 1388 |
|
| 1389 |
**Dataset:**
|
| 1390 |
- 20 Newsgroups dataset
|
| 1391 |
+
- Publicly available via Hugging Face
|
| 1392 |
- Approximately 18,000 newsgroup posts
|
| 1393 |
- 20 categories covering diverse topics
|
| 1394 |
- No personal or sensitive information
|
|
|
|
| 1416 |
**Acknowledgments:**
|
| 1417 |
- 20 Newsgroups dataset creators
|
| 1418 |
- scikit-learn team
|
| 1419 |
+
- Hugging Face for sentence-transformers and dataset hosting
|
| 1420 |
- Open source ML community
|
| 1421 |
""")
|
| 1422 |
|