SpencerCPurdy committed on
Commit
ff01b81
·
verified ·
1 Parent(s): 29b2b91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -24
app.py CHANGED
@@ -66,7 +66,7 @@ if torch.cuda.is_available():
66
  # Core libraries
67
  import pandas as pd
68
  import numpy as np
69
- from sklearn.datasets import fetch_20newsgroups
70
  from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
71
  from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
72
  from sklearn.preprocessing import LabelEncoder, StandardScaler
@@ -185,7 +185,7 @@ class NewsGroupsDataLoader:
185
  Loads and preprocesses the 20 Newsgroups dataset.
186
 
187
  Dataset Information:
188
- - Source: 20 Newsgroups dataset (publicly available via scikit-learn)
189
  - License: Public domain
190
  - Size: ~18,000 newsgroup posts across 20 categories
191
  - Task: Multi-class text classification
@@ -208,26 +208,21 @@ class NewsGroupsDataLoader:
208
  Returns:
209
  Tuple of (train_df, val_df, test_df)
210
  """
211
- logger.info("Loading 20 Newsgroups dataset...")
212
-
213
- # Load training data
214
- train_data = fetch_20newsgroups(
215
- subset='train',
216
- remove=('headers', 'footers', 'quotes'),
217
- random_state=self.config.random_seed
218
- )
219
-
220
- # Load test data
221
- test_data = fetch_20newsgroups(
222
- subset='test',
223
- remove=('headers', 'footers', 'quotes'),
224
- random_state=self.config.random_seed
225
- )
226
-
227
  # Combine for proper splitting
228
- all_texts = list(train_data.data) + list(test_data.data)
229
- all_labels = list(train_data.target) + list(test_data.target)
230
- self.categories = train_data.target_names
 
 
231
 
232
  logger.info(f"Total documents: {len(all_texts)}")
233
  logger.info(f"Number of categories: {len(self.categories)}")
@@ -1326,7 +1321,7 @@ def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
1326
 
1327
  gr.Markdown("""
1328
  ### Model Limitations and Failure Cases
1329
-
1330
  **Known Limitations:**
1331
  1. **Domain Specificity**: Trained on newsgroup data, may not generalize well to
1332
  significantly different domains (e.g., legal documents, medical reports)
@@ -1393,7 +1388,7 @@ def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
1393
 
1394
  **Dataset:**
1395
  - 20 Newsgroups dataset
1396
- - Publicly available via scikit-learn
1397
  - Approximately 18,000 newsgroup posts
1398
  - 20 categories covering diverse topics
1399
  - No personal or sensitive information
@@ -1421,7 +1416,7 @@ def create_gradio_interface(system: MultiAgentSystem, training_results: Dict):
1421
  **Acknowledgments:**
1422
  - 20 Newsgroups dataset creators
1423
  - scikit-learn team
1424
- - Hugging Face for sentence-transformers
1425
  - Open source ML community
1426
  """)
1427
 
 
66
  # Core libraries
67
  import pandas as pd
68
  import numpy as np
69
+ from datasets import load_dataset
70
  from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
71
  from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
72
  from sklearn.preprocessing import LabelEncoder, StandardScaler
 
185
  Loads and preprocesses the 20 Newsgroups dataset.
186
 
187
  Dataset Information:
188
+ - Source: 20 Newsgroups dataset (publicly available via Hugging Face)
189
  - License: Public domain
190
  - Size: ~18,000 newsgroup posts across 20 categories
191
  - Task: Multi-class text classification
 
208
  Returns:
209
  Tuple of (train_df, val_df, test_df)
210
  """
211
+ logger.info("Loading 20 Newsgroups dataset from Hugging Face...")
212
+
213
+ # Load dataset from Hugging Face
214
+ dataset = load_dataset("SetFit/20_newsgroups")
215
+
216
+ # Extract train and test data
217
+ train_data = dataset['train']
218
+ test_data = dataset['test']
219
+
 
 
 
 
 
 
 
220
  # Combine for proper splitting
221
+ all_texts = train_data['text'] + test_data['text']
222
+ all_labels = train_data['label'] + test_data['label']
223
+
224
+ # Get category names from dataset features
225
+ self.categories = train_data.features['label'].names
226
 
227
  logger.info(f"Total documents: {len(all_texts)}")
228
  logger.info(f"Number of categories: {len(self.categories)}")
 
1321
 
1322
  gr.Markdown("""
1323
  ### Model Limitations and Failure Cases
1324
+
1325
  **Known Limitations:**
1326
  1. **Domain Specificity**: Trained on newsgroup data, may not generalize well to
1327
  significantly different domains (e.g., legal documents, medical reports)
 
1388
 
1389
  **Dataset:**
1390
  - 20 Newsgroups dataset
1391
+ - Publicly available via Hugging Face
1392
  - Approximately 18,000 newsgroup posts
1393
  - 20 categories covering diverse topics
1394
  - No personal or sensitive information
 
1416
  **Acknowledgments:**
1417
  - 20 Newsgroups dataset creators
1418
  - scikit-learn team
1419
+ - Hugging Face for sentence-transformers and dataset hosting
1420
  - Open source ML community
1421
  """)
1422