Snaseem2026 committed
Commit 4089b4a · verified · 1 Parent(s): 403d57b

Upload folder using huggingface_hub

Files changed (5):
  1. src/__init__.py +41 -0
  2. src/data_loader.py +205 -0
  3. src/model.py +108 -0
  4. src/utils.py +255 -0
  5. src/validation.py +174 -0
src/__init__.py ADDED
@@ -0,0 +1,41 @@
"""
Initialization for src package
"""
from .data_loader import load_config, prepare_datasets_for_training
from .model import (
    create_model,
    get_model_size,
    get_trainable_params,
    apply_class_weights
)
from .utils import (
    compute_metrics,
    compute_metrics_factory,
    plot_confusion_matrix,
    print_classification_report,
    plot_training_curves
)
from .validation import (
    validate_config,
    validate_model_path,
    validate_data_file,
    validate_config_file
)

__all__ = [
    'load_config',
    'prepare_datasets_for_training',
    'create_model',
    'get_model_size',
    'get_trainable_params',
    'apply_class_weights',
    'compute_metrics',
    'compute_metrics_factory',
    'plot_confusion_matrix',
    'print_classification_report',
    'plot_training_curves',
    'validate_config',
    'validate_model_path',
    'validate_data_file',
    'validate_config_file'
]
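With these re-exports in place, downstream scripts can pull everything from the package root instead of the individual modules. A small illustration (assuming the repository root is on the import path; the config path is a placeholder):

from src import load_config, create_model, compute_metrics_factory

config = load_config("config.yaml")  # same as src.data_loader.load_config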
src/data_loader.py ADDED
@@ -0,0 +1,205 @@
"""
Data loader utilities for Code Comment Quality Classifier
"""
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from typing import Tuple, Dict, List, Optional
import yaml
import logging
import os
from pathlib import Path


def load_config(config_path: str = "config.yaml") -> dict:
    """Load configuration from YAML file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config


def load_data(data_path: str) -> pd.DataFrame:
    """
    Load data from CSV file with validation.

    Expected format:
    - comment: str (the code comment text)
    - label: str (excellent, helpful, unclear, or outdated)

    Args:
        data_path: Path to the CSV file

    Returns:
        DataFrame with validated data

    Raises:
        FileNotFoundError: If data file doesn't exist
        ValueError: If data format is invalid
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    df = pd.read_csv(data_path)

    # Validate required columns
    required_columns = ['comment', 'label']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Remove rows with missing values
    initial_len = len(df)
    df = df.dropna(subset=required_columns)
    if len(df) < initial_len:
        logging.warning(f"Removed {initial_len - len(df)} rows with missing values")

    # Remove empty comments
    df = df[df['comment'].str.strip().str.len() > 0]

    # Validate labels
    if df['label'].isna().any():
        logging.warning("Found NaN labels, removing those rows")
        df = df.dropna(subset=['label'])

    logging.info(f"Loaded {len(df)} samples from {data_path}")
    return df


def create_label_mapping(labels: list) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create bidirectional label mapping."""
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for idx, label in enumerate(labels)}
    return label2id, id2label


def prepare_dataset(
    df: pd.DataFrame,
    label2id: Dict[str, int],
    train_size: float = 0.8,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    stratify: bool = True
) -> DatasetDict:
    """
    Prepare dataset splits for training.

    Args:
        df: DataFrame with 'comment' and 'label' columns
        label2id: Mapping from label names to IDs
        train_size: Proportion of training data
        val_size: Proportion of validation data
        test_size: Proportion of test data
        seed: Random seed for reproducibility
        stratify: Whether to maintain class distribution in splits

    Returns:
        DatasetDict with train, validation, and test splits
    """
    # Validate label distribution
    invalid_labels = set(df['label'].unique()) - set(label2id.keys())
    if invalid_labels:
        raise ValueError(f"Found invalid labels: {invalid_labels}. Expected: {list(label2id.keys())}")

    # Convert labels to IDs
    df['label_id'] = df['label'].map(label2id)

    # Check for missing mappings
    if df['label_id'].isna().any():
        missing_labels = df[df['label_id'].isna()]['label'].unique()
        raise ValueError(f"Labels not found in label2id mapping: {missing_labels}")

    # Validate split proportions
    total_size = train_size + val_size + test_size
    if abs(total_size - 1.0) > 1e-6:
        raise ValueError(f"Split sizes must sum to 1.0, got {total_size}")

    # Stratification column
    stratify_col = df['label_id'] if stratify else None

    # First split: separate test set
    train_val_df, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=seed,
        stratify=stratify_col
    )

    # Second split: separate train and validation
    val_size_adjusted = val_size / (train_size + val_size)
    stratify_col_train = train_val_df['label_id'] if stratify else None
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_size_adjusted,
        random_state=seed,
        stratify=stratify_col_train
    )

    # Log distribution
    logging.info(f"Dataset splits - Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    logging.info(f"Train label distribution:\n{train_df['label'].value_counts().sort_index()}")

    # Create datasets
    dataset_dict = DatasetDict({
        'train': Dataset.from_pandas(train_df[['comment', 'label_id']], preserve_index=False),
        'validation': Dataset.from_pandas(val_df[['comment', 'label_id']], preserve_index=False),
        'test': Dataset.from_pandas(test_df[['comment', 'label_id']], preserve_index=False)
    })

    return dataset_dict


def tokenize_function(examples, tokenizer, max_length: int = 512):
    """Tokenize the input text."""
    return tokenizer(
        examples['comment'],
        padding='max_length',
        truncation=True,
        max_length=max_length
    )


def prepare_datasets_for_training(config_path: str = "config.yaml"):
    """
    Complete pipeline to prepare datasets for training.

    Returns:
        Tuple of (tokenized_datasets, label2id, id2label, tokenizer)
    """
    from transformers import AutoTokenizer

    config = load_config(config_path)

    # Load data
    df = load_data(config['data']['data_path'])

    # Create label mappings
    labels = config['labels']
    label2id, id2label = create_label_mapping(labels)

    # Prepare dataset splits
    stratify = config['data'].get('stratify', True)
    dataset_dict = prepare_dataset(
        df,
        label2id,
        train_size=config['data']['train_size'],
        val_size=config['data']['val_size'],
        test_size=config['data']['test_size'],
        seed=config['training']['seed'],
        stratify=stratify
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['model']['name'])

    # Tokenize datasets
    tokenized_datasets = dataset_dict.map(
        lambda x: tokenize_function(x, tokenizer, config['model']['max_length']),
        batched=True,
        remove_columns=['comment']
    )

    # Rename label_id to labels for training
    tokenized_datasets = tokenized_datasets.rename_column('label_id', 'labels')

    return tokenized_datasets, label2id, id2label, tokenizer
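For context, a minimal sketch of the config.yaml this loader expects and how the pipeline would be invoked. The key names mirror the accesses in this module (model.name, model.max_length, data.data_path, data.train_size/val_size/test_size, data.stratify, training.seed, labels); the concrete values, including the base model name and data path, are placeholder assumptions, not pinned by this commit.

# Hypothetical config.yaml (values are assumptions):
#
#   model:
#     name: bert-base-uncased
#     max_length: 256
#   training:
#     seed: 42
#   data:
#     data_path: data/comments.csv
#     train_size: 0.8
#     val_size: 0.1
#     test_size: 0.1
#     stratify: true
#   labels: [excellent, helpful, unclear, outdated]

from src.data_loader import prepare_datasets_for_training

tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training("config.yaml")
print(tokenized_datasets)  # DatasetDict with 'train', 'validation', 'test' splits
print(label2id)            # e.g. {'excellent': 0, 'helpful': 1, 'unclear': 2, 'outdated': 3}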
src/model.py ADDED
@@ -0,0 +1,108 @@
"""
Model definition and utilities
"""
from transformers import AutoModelForSequenceClassification, AutoConfig
from typing import Dict, Optional
import logging
import torch
import torch.nn as nn


def create_model(
    model_name: str,
    num_labels: int,
    label2id: Dict[str, int],
    id2label: Dict[int, str],
    dropout: Optional[float] = None
):
    """
    Create a sequence classification model with optional dropout configuration.

    Args:
        model_name: Name of the pretrained model
        num_labels: Number of classification labels
        label2id: Mapping from label names to IDs
        id2label: Mapping from IDs to label names
        dropout: Optional dropout probability for classifier head

    Returns:
        Initialized model
    """
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label
    )

    # Set dropout if provided
    if dropout is not None:
        if hasattr(config, 'hidden_dropout_prob'):
            config.hidden_dropout_prob = dropout
        if hasattr(config, 'attention_probs_dropout_prob'):
            config.attention_probs_dropout_prob = dropout
        if hasattr(config, 'classifier_dropout'):
            config.classifier_dropout = dropout
        logging.info(f"Set model dropout to {dropout}")

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config
    )

    return model


def apply_class_weights(
    model: nn.Module,
    class_weights: Optional[list] = None
) -> Optional[torch.Tensor]:
    """
    Build a class-weight tensor for use in the model's loss function.

    Args:
        model: The model to modify
        class_weights: List of weights for each class (must match num_labels)

    Returns:
        Class-weight tensor (if class_weights provided) for use in a custom
        Trainer with weighted loss, otherwise None
    """
    if class_weights is not None:
        weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
        # Note: This requires a custom Trainer with weighted loss
        logging.info(f"Class weights applied: {class_weights}")
        return weights_tensor
    return None


def get_model_size(model: nn.Module) -> float:
    """
    Calculate model size in millions of parameters.

    Args:
        model: PyTorch model

    Returns:
        Number of parameters in millions
    """
    param_size = sum(p.numel() for p in model.parameters())
    return param_size / 1e6


def get_trainable_params(model: nn.Module) -> Dict[str, int]:
    """
    Get count of trainable and non-trainable parameters.

    Args:
        model: PyTorch model

    Returns:
        Dictionary with 'trainable' and 'total' parameter counts
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return {
        'trainable': trainable,
        'total': total,
        'non_trainable': total - trainable
    }
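apply_class_weights only returns the weight tensor; as its inline note says, the weighting itself has to happen in a custom Trainer. A minimal sketch of how that could be wired up, assuming the class name WeightedLossTrainer, the model name, and the weight values, none of which are part of this commit:

import torch.nn as nn
from transformers import Trainer

from src.model import create_model, apply_class_weights

class WeightedLossTrainer(Trainer):
    """Trainer that applies the class-weight tensor from apply_class_weights."""

    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Move the weight tensor to the logits' device before building the loss
        weights = self.class_weights.to(logits.device) if self.class_weights is not None else None
        loss_fct = nn.CrossEntropyLoss(weight=weights)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

# Example wiring (placeholder model name and weights):
model = create_model("bert-base-uncased", num_labels=len(label2id), label2id=label2id, id2label=id2label, dropout=0.1)
weights = apply_class_weights(model, class_weights=[1.0, 1.0, 2.0, 2.0])
# trainer = WeightedLossTrainer(model=model, args=training_args, class_weights=weights, ...)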
src/utils.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for training and evaluation
3
+ """
4
+ import numpy as np
5
+ from sklearn.metrics import (
6
+ accuracy_score,
7
+ precision_recall_fscore_support,
8
+ confusion_matrix,
9
+ classification_report
10
+ )
11
+ import matplotlib.pyplot as plt
12
+ import seaborn as sns
13
+ from typing import Dict, Tuple, List, Optional
14
+ import os
15
+
16
+
17
+ def compute_metrics(eval_pred, id2label: Optional[Dict[int, str]] = None) -> Dict[str, float]:
18
+ """
19
+ Compute comprehensive metrics for evaluation.
20
+
21
+ Args:
22
+ eval_pred: Tuple of (predictions, labels)
23
+ id2label: Optional mapping from label IDs to label names for per-class metrics
24
+
25
+ Returns:
26
+ Dictionary of metrics including overall and per-class metrics
27
+ """
28
+ predictions, labels = eval_pred
29
+ predictions = np.argmax(predictions, axis=1)
30
+
31
+ # Overall metrics
32
+ accuracy = accuracy_score(labels, predictions)
33
+
34
+ # Weighted metrics (accounts for class imbalance)
35
+ precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
36
+ labels,
37
+ predictions,
38
+ average='weighted',
39
+ zero_division=0
40
+ )
41
+
42
+ # Macro-averaged metrics (treats all classes equally)
43
+ precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
44
+ labels,
45
+ predictions,
46
+ average='macro',
47
+ zero_division=0
48
+ )
49
+
50
+ # Micro-averaged metrics (aggregates contributions of all classes)
51
+ precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
52
+ labels,
53
+ predictions,
54
+ average='micro',
55
+ zero_division=0
56
+ )
57
+
58
+ metrics = {
59
+ 'accuracy': accuracy,
60
+ 'precision_weighted': precision_weighted,
61
+ 'recall_weighted': recall_weighted,
62
+ 'f1_weighted': f1_weighted,
63
+ 'precision_macro': precision_macro,
64
+ 'recall_macro': recall_macro,
65
+ 'f1_macro': f1_macro,
66
+ 'precision_micro': precision_micro,
67
+ 'recall_micro': recall_micro,
68
+ 'f1_micro': f1_micro,
69
+ }
70
+
71
+ # Per-class metrics if label mapping is provided
72
+ if id2label is not None:
73
+ num_classes = len(id2label)
74
+ precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
75
+ labels,
76
+ predictions,
77
+ labels=list(range(num_classes)),
78
+ average=None,
79
+ zero_division=0
80
+ )
81
+
82
+ for i in range(num_classes):
83
+ label_name = id2label[i]
84
+ metrics[f'precision_{label_name}'] = float(precision_per_class[i])
85
+ metrics[f'recall_{label_name}'] = float(recall_per_class[i])
86
+ metrics[f'f1_{label_name}'] = float(f1_per_class[i])
87
+ metrics[f'support_{label_name}'] = int(support[i])
88
+
89
+ return metrics
90
+
91
+
92
+ def compute_metrics_factory(id2label: Optional[Dict[int, str]] = None):
93
+ """
94
+ Factory function to create compute_metrics with label mapping.
95
+
96
+ Args:
97
+ id2label: Mapping from label IDs to label names
98
+
99
+ Returns:
100
+ Function compatible with HuggingFace Trainer
101
+ """
102
+ def compute_metrics_fn(eval_pred):
103
+ return compute_metrics(eval_pred, id2label)
104
+
105
+ return compute_metrics_fn
106
+
107
+
108
+ def plot_confusion_matrix(
109
+ y_true: np.ndarray,
110
+ y_pred: np.ndarray,
111
+ labels: List[str],
112
+ save_path: str = "confusion_matrix.png",
113
+ normalize: bool = False,
114
+ figsize: Tuple[int, int] = (10, 8)
115
+ ) -> None:
116
+ """
117
+ Plot and save confusion matrix with optional normalization.
118
+
119
+ Args:
120
+ y_true: True labels
121
+ y_pred: Predicted labels
122
+ labels: List of label names
123
+ save_path: Path to save the plot
124
+ normalize: If True, normalize confusion matrix to percentages
125
+ figsize: Figure size (width, height)
126
+ """
127
+ cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
128
+
129
+ if normalize:
130
+ cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
131
+ fmt = '.2f'
132
+ title = 'Normalized Confusion Matrix'
133
+ else:
134
+ fmt = 'd'
135
+ title = 'Confusion Matrix'
136
+
137
+ plt.figure(figsize=figsize)
138
+ sns.heatmap(
139
+ cm,
140
+ annot=True,
141
+ fmt=fmt,
142
+ cmap='Blues',
143
+ xticklabels=labels,
144
+ yticklabels=labels,
145
+ cbar_kws={'label': 'Percentage' if normalize else 'Count'}
146
+ )
147
+ plt.title(title, fontsize=14, fontweight='bold')
148
+ plt.ylabel('True Label', fontsize=12)
149
+ plt.xlabel('Predicted Label', fontsize=12)
150
+ plt.tight_layout()
151
+
152
+ # Create directory if it doesn't exist
153
+ os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
154
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
155
+ plt.close()
156
+
157
+ print(f"Confusion matrix saved to {save_path}")
158
+
159
+
160
+ def print_classification_report(
161
+ y_true: np.ndarray,
162
+ y_pred: np.ndarray,
163
+ labels: List[str],
164
+ output_dict: bool = False
165
+ ) -> Optional[Dict]:
166
+ """
167
+ Print detailed classification report.
168
+
169
+ Args:
170
+ y_true: True labels
171
+ y_pred: Predicted labels
172
+ labels: List of label names
173
+ output_dict: If True, return report as dictionary instead of printing
174
+
175
+ Returns:
176
+ Classification report as dictionary if output_dict=True, else None
177
+ """
178
+ report = classification_report(
179
+ y_true,
180
+ y_pred,
181
+ target_names=labels,
182
+ digits=4,
183
+ output_dict=output_dict,
184
+ zero_division=0
185
+ )
186
+
187
+ if output_dict:
188
+ return report
189
+
190
+ print("\nClassification Report:")
191
+ print("=" * 60)
192
+ print(report)
193
+ return None
194
+
195
+
196
+ def plot_training_curves(
197
+ train_losses: List[float],
198
+ eval_losses: List[float],
199
+ eval_metrics: Dict[str, List[float]],
200
+ save_path: str = "./results/training_curves.png"
201
+ ) -> None:
202
+ """
203
+ Plot training and evaluation curves.
204
+
205
+ Args:
206
+ train_losses: List of training losses per step/epoch
207
+ eval_losses: List of evaluation losses per step/epoch
208
+ eval_metrics: Dictionary of metric names to lists of values
209
+ save_path: Path to save the plot
210
+ """
211
+ fig, axes = plt.subplots(2, 2, figsize=(15, 10))
212
+
213
+ # Loss curves
214
+ axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
215
+ axes[0, 0].plot(eval_losses, label='Eval Loss', color='red')
216
+ axes[0, 0].set_xlabel('Step/Epoch')
217
+ axes[0, 0].set_ylabel('Loss')
218
+ axes[0, 0].set_title('Training and Validation Loss')
219
+ axes[0, 0].legend()
220
+ axes[0, 0].grid(True, alpha=0.3)
221
+
222
+ # Accuracy
223
+ if 'accuracy' in eval_metrics:
224
+ axes[0, 1].plot(eval_metrics['accuracy'], label='Accuracy', color='green')
225
+ axes[0, 1].set_xlabel('Step/Epoch')
226
+ axes[0, 1].set_ylabel('Accuracy')
227
+ axes[0, 1].set_title('Validation Accuracy')
228
+ axes[0, 1].legend()
229
+ axes[0, 1].grid(True, alpha=0.3)
230
+
231
+ # F1 Score
232
+ if 'f1_weighted' in eval_metrics:
233
+ axes[1, 0].plot(eval_metrics['f1_weighted'], label='F1 (weighted)', color='purple')
234
+ axes[1, 0].set_xlabel('Step/Epoch')
235
+ axes[1, 0].set_ylabel('F1 Score')
236
+ axes[1, 0].set_title('Validation F1 Score')
237
+ axes[1, 0].legend()
238
+ axes[1, 0].grid(True, alpha=0.3)
239
+
240
+ # Precision and Recall
241
+ if 'precision_weighted' in eval_metrics and 'recall_weighted' in eval_metrics:
242
+ axes[1, 1].plot(eval_metrics['precision_weighted'], label='Precision', color='orange')
243
+ axes[1, 1].plot(eval_metrics['recall_weighted'], label='Recall', color='cyan')
244
+ axes[1, 1].set_xlabel('Step/Epoch')
245
+ axes[1, 1].set_ylabel('Score')
246
+ axes[1, 1].set_title('Validation Precision and Recall')
247
+ axes[1, 1].legend()
248
+ axes[1, 1].grid(True, alpha=0.3)
249
+
250
+ plt.tight_layout()
251
+ os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
252
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
253
+ plt.close()
254
+
255
+ print(f"Training curves saved to {save_path}")
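A possible way to wire these utilities into a HuggingFace Trainer run, continuing from the loader and model sketches above. The output directory, variable names carried over from earlier snippets (model, tokenized_datasets, id2label), and the decision to evaluate on the test split afterwards are assumptions, not part of this commit:

import numpy as np
from transformers import Trainer, TrainingArguments

from src.utils import compute_metrics_factory, plot_confusion_matrix, print_classification_report

trainer = Trainer(
    model=model,                                   # e.g. from create_model(...)
    args=TrainingArguments(output_dir="./results"),
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics_factory(id2label),  # adds per-class metrics to eval logs
)
trainer.train()

# Test-set diagnostics using the plotting helpers
pred_output = trainer.predict(tokenized_datasets["test"])
y_pred = np.argmax(pred_output.predictions, axis=1)
y_true = pred_output.label_ids
label_names = [id2label[i] for i in range(len(id2label))]
plot_confusion_matrix(y_true, y_pred, label_names, save_path="./results/confusion_matrix.png")
print_classification_report(y_true, y_pred, label_names)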
src/validation.py ADDED
@@ -0,0 +1,174 @@
"""
Validation utilities for model and data validation
"""
import os
import yaml
from typing import Dict, List, Optional
import logging
from pathlib import Path


def validate_config(config: Dict) -> List[str]:
    """
    Validate configuration file for common issues.

    Args:
        config: Configuration dictionary

    Returns:
        List of validation error messages (empty if valid)
    """
    errors = []

    # Check required sections
    required_sections = ['model', 'training', 'data', 'labels']
    for section in required_sections:
        if section not in config:
            errors.append(f"Missing required section: {section}")

    if errors:
        return errors

    # Validate model section
    if 'name' not in config['model']:
        errors.append("model.name is required")
    if 'num_labels' not in config['model']:
        errors.append("model.num_labels is required")
    elif config['model']['num_labels'] != len(config.get('labels', [])):
        errors.append(f"model.num_labels ({config['model']['num_labels']}) doesn't match number of labels ({len(config['labels'])})")

    # Validate training section
    training = config['training']
    if 'num_train_epochs' in training and training['num_train_epochs'] <= 0:
        errors.append("training.num_train_epochs must be positive")
    if 'learning_rate' in training and training['learning_rate'] <= 0:
        errors.append("training.learning_rate must be positive")
    if 'per_device_train_batch_size' in training and training['per_device_train_batch_size'] <= 0:
        errors.append("training.per_device_train_batch_size must be positive")

    # Validate data section
    data = config['data']
    if 'data_path' in data and not os.path.exists(data['data_path']):
        errors.append(f"Data file not found: {data['data_path']}")

    train_size = data.get('train_size', 0)
    val_size = data.get('val_size', 0)
    test_size = data.get('test_size', 0)
    total = train_size + val_size + test_size
    if abs(total - 1.0) > 1e-6:
        errors.append(f"Data split sizes must sum to 1.0, got {total}")

    # Validate labels
    if 'labels' not in config or not config['labels']:
        errors.append("labels section is required and cannot be empty")
    elif len(set(config['labels'])) != len(config['labels']):
        errors.append("labels must be unique")

    return errors


def validate_model_path(model_path: str) -> bool:
    """
    Validate that model path exists and contains required files.

    Args:
        model_path: Path to model directory

    Returns:
        True if valid, False otherwise
    """
    if not os.path.exists(model_path):
        logging.error(f"Model path does not exist: {model_path}")
        return False

    required_files = ['config.json']
    for file in required_files:
        file_path = os.path.join(model_path, file)
        if not os.path.exists(file_path):
            logging.error(f"Required file missing: {file_path}")
            return False

    return True


def validate_data_file(data_path: str, required_columns: Optional[List[str]] = None) -> List[str]:
    """
    Validate data file format and content.

    Args:
        data_path: Path to data file
        required_columns: List of required column names

    Returns:
        List of validation error messages (empty if valid)
    """
    errors = []

    if required_columns is None:
        required_columns = ['comment', 'label']

    if not os.path.exists(data_path):
        errors.append(f"Data file not found: {data_path}")
        return errors

    try:
        import pandas as pd
        df = pd.read_csv(data_path)

        # Check required columns
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            errors.append(f"Missing required columns: {missing_columns}")

        # Check for empty dataframe
        if len(df) == 0:
            errors.append("Data file is empty")

        # Check for missing values in required columns
        if 'comment' in df.columns:
            empty_comments = df['comment'].isna().sum() + (df['comment'].str.strip().str.len() == 0).sum()
            if empty_comments > 0:
                errors.append(f"Found {empty_comments} empty comments")

        if 'label' in df.columns:
            missing_labels = df['label'].isna().sum()
            if missing_labels > 0:
                errors.append(f"Found {missing_labels} missing labels")

    except Exception as e:
        errors.append(f"Error reading data file: {str(e)}")

    return errors


def validate_config_file(config_path: str) -> bool:
    """
    Validate configuration file.

    Args:
        config_path: Path to configuration file

    Returns:
        True if valid, False otherwise
    """
    if not os.path.exists(config_path):
        logging.error(f"Config file not found: {config_path}")
        return False

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        errors = validate_config(config)
        if errors:
            logging.error("Configuration validation errors:")
            for error in errors:
                logging.error(f"  - {error}")
            return False

        logging.info("Configuration file is valid")
        return True

    except Exception as e:
        logging.error(f"Error reading config file: {str(e)}")
        return False
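A short pre-flight sketch showing how these checks could gate a training run; the file paths are placeholders consistent with the config keys used elsewhere in this commit:

import sys
import logging

from src.validation import validate_config_file, validate_data_file

logging.basicConfig(level=logging.INFO)

# Fail fast if the configuration is missing sections or has inconsistent splits
if not validate_config_file("config.yaml"):
    sys.exit("Aborting: config.yaml failed validation")

# Surface data problems (missing columns, empty comments, missing labels) before training
data_errors = validate_data_file("data/comments.csv")
if data_errors:
    for err in data_errors:
        logging.error(err)
    sys.exit("Aborting: data file failed validation")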