Upload folder using huggingface_hub
- src/__init__.py +41 -0
- src/data_loader.py +205 -0
- src/model.py +108 -0
- src/utils.py +255 -0
- src/validation.py +174 -0
src/__init__.py
ADDED
@@ -0,0 +1,41 @@
"""
Initialization for src package
"""
from .data_loader import load_config, prepare_datasets_for_training
from .model import (
    create_model,
    get_model_size,
    get_trainable_params,
    apply_class_weights
)
from .utils import (
    compute_metrics,
    compute_metrics_factory,
    plot_confusion_matrix,
    print_classification_report,
    plot_training_curves
)
from .validation import (
    validate_config,
    validate_model_path,
    validate_data_file,
    validate_config_file
)

__all__ = [
    'load_config',
    'prepare_datasets_for_training',
    'create_model',
    'get_model_size',
    'get_trainable_params',
    'apply_class_weights',
    'compute_metrics',
    'compute_metrics_factory',
    'plot_confusion_matrix',
    'print_classification_report',
    'plot_training_curves',
    'validate_config',
    'validate_model_path',
    'validate_data_file',
    'validate_config_file'
]
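With the package init in place, the helpers can be imported directly from src. A minimal sketch, assuming the repository root is on sys.path and a config.yaml is present:

# Minimal sketch only; assumes the repo root is importable and config.yaml exists.
from src import load_config, validate_config

config = load_config("config.yaml")
errors = validate_config(config)
print(errors or "config looks valid")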
src/data_loader.py
ADDED
@@ -0,0 +1,205 @@
"""
Data loader utilities for Code Comment Quality Classifier
"""
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from typing import Tuple, Dict, List, Optional
import yaml
import logging
import os
from pathlib import Path


def load_config(config_path: str = "config.yaml") -> dict:
    """Load configuration from YAML file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config


def load_data(data_path: str) -> pd.DataFrame:
    """
    Load data from CSV file with validation.

    Expected format:
    - comment: str (the code comment text)
    - label: str (excellent, helpful, unclear, or outdated)

    Args:
        data_path: Path to the CSV file

    Returns:
        DataFrame with validated data

    Raises:
        FileNotFoundError: If data file doesn't exist
        ValueError: If data format is invalid
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    df = pd.read_csv(data_path)

    # Validate required columns
    required_columns = ['comment', 'label']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Remove rows with missing values
    initial_len = len(df)
    df = df.dropna(subset=required_columns)
    if len(df) < initial_len:
        logging.warning(f"Removed {initial_len - len(df)} rows with missing values")

    # Remove empty comments
    df = df[df['comment'].str.strip().str.len() > 0]

    # Validate labels
    if df['label'].isna().any():
        logging.warning("Found NaN labels, removing those rows")
        df = df.dropna(subset=['label'])

    logging.info(f"Loaded {len(df)} samples from {data_path}")
    return df


def create_label_mapping(labels: list) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create bidirectional label mapping."""
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for idx, label in enumerate(labels)}
    return label2id, id2label


def prepare_dataset(
    df: pd.DataFrame,
    label2id: Dict[str, int],
    train_size: float = 0.8,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    stratify: bool = True
) -> DatasetDict:
    """
    Prepare dataset splits for training.

    Args:
        df: DataFrame with 'comment' and 'label' columns
        label2id: Mapping from label names to IDs
        train_size: Proportion of training data
        val_size: Proportion of validation data
        test_size: Proportion of test data
        seed: Random seed for reproducibility
        stratify: Whether to maintain class distribution in splits

    Returns:
        DatasetDict with train, validation, and test splits
    """
    # Validate label distribution
    invalid_labels = set(df['label'].unique()) - set(label2id.keys())
    if invalid_labels:
        raise ValueError(f"Found invalid labels: {invalid_labels}. Expected: {list(label2id.keys())}")

    # Convert labels to IDs
    df['label_id'] = df['label'].map(label2id)

    # Check for missing mappings
    if df['label_id'].isna().any():
        missing_labels = df[df['label_id'].isna()]['label'].unique()
        raise ValueError(f"Labels not found in label2id mapping: {missing_labels}")

    # Validate split proportions
    total_size = train_size + val_size + test_size
    if abs(total_size - 1.0) > 1e-6:
        raise ValueError(f"Split sizes must sum to 1.0, got {total_size}")

    # Stratification column
    stratify_col = df['label_id'] if stratify else None

    # First split: separate test set
    train_val_df, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=seed,
        stratify=stratify_col
    )

    # Second split: separate train and validation
    val_size_adjusted = val_size / (train_size + val_size)
    stratify_col_train = train_val_df['label_id'] if stratify else None
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_size_adjusted,
        random_state=seed,
        stratify=stratify_col_train
    )

    # Log distribution
    logging.info(f"Dataset splits - Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    logging.info(f"Train label distribution:\n{train_df['label'].value_counts().sort_index()}")

    # Create datasets
    dataset_dict = DatasetDict({
        'train': Dataset.from_pandas(train_df[['comment', 'label_id']], preserve_index=False),
        'validation': Dataset.from_pandas(val_df[['comment', 'label_id']], preserve_index=False),
        'test': Dataset.from_pandas(test_df[['comment', 'label_id']], preserve_index=False)
    })

    return dataset_dict


def tokenize_function(examples, tokenizer, max_length: int = 512):
    """Tokenize the input text."""
    return tokenizer(
        examples['comment'],
        padding='max_length',
        truncation=True,
        max_length=max_length
    )


def prepare_datasets_for_training(config_path: str = "config.yaml"):
    """
    Complete pipeline to prepare datasets for training.

    Returns:
        Tuple of (tokenized_datasets, label2id, id2label, tokenizer)
    """
    from transformers import AutoTokenizer

    config = load_config(config_path)

    # Load data
    df = load_data(config['data']['data_path'])

    # Create label mappings
    labels = config['labels']
    label2id, id2label = create_label_mapping(labels)

    # Prepare dataset splits
    stratify = config['data'].get('stratify', True)
    dataset_dict = prepare_dataset(
        df,
        label2id,
        train_size=config['data']['train_size'],
        val_size=config['data']['val_size'],
        test_size=config['data']['test_size'],
        seed=config['training']['seed'],
        stratify=stratify
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config['model']['name'])

    # Tokenize datasets
    tokenized_datasets = dataset_dict.map(
        lambda x: tokenize_function(x, tokenizer, config['model']['max_length']),
        batched=True,
        remove_columns=['comment']
    )

    # Rename label_id to labels for training
    tokenized_datasets = tokenized_datasets.rename_column('label_id', 'labels')

    return tokenized_datasets, label2id, id2label, tokenizer
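prepare_datasets_for_training is the module's entry point: it chains config loading, CSV loading, stratified splitting, and tokenization. A minimal usage sketch, assuming config.yaml carries the keys the function reads (data.data_path, data.train_size/val_size/test_size, model.name, model.max_length, training.seed, labels):

# Sketch only; the config keys listed above must exist in config.yaml.
from src.data_loader import prepare_datasets_for_training

tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training("config.yaml")
print(tokenized_datasets)   # DatasetDict with 'train', 'validation', 'test' splits
print(label2id)             # e.g. {'excellent': 0, 'helpful': 1, 'unclear': 2, 'outdated': 3}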
src/model.py
ADDED
@@ -0,0 +1,108 @@
"""
Model definition and utilities
"""
from transformers import AutoModelForSequenceClassification, AutoConfig
from typing import Dict, Optional
import logging
import torch
import torch.nn as nn


def create_model(
    model_name: str,
    num_labels: int,
    label2id: Dict[str, int],
    id2label: Dict[int, str],
    dropout: Optional[float] = None
):
    """
    Create a sequence classification model with optional dropout configuration.

    Args:
        model_name: Name of the pretrained model
        num_labels: Number of classification labels
        label2id: Mapping from label names to IDs
        id2label: Mapping from IDs to label names
        dropout: Optional dropout probability for classifier head

    Returns:
        Initialized model
    """
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label
    )

    # Set dropout if provided
    if dropout is not None:
        if hasattr(config, 'hidden_dropout_prob'):
            config.hidden_dropout_prob = dropout
        if hasattr(config, 'attention_probs_dropout_prob'):
            config.attention_probs_dropout_prob = dropout
        if hasattr(config, 'classifier_dropout'):
            config.classifier_dropout = dropout
        logging.info(f"Set model dropout to {dropout}")

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config
    )

    return model


def apply_class_weights(
    model: nn.Module,
    class_weights: Optional[list] = None
) -> Optional[torch.Tensor]:
    """
    Build a class-weight tensor for the model's loss function.

    Args:
        model: The model to modify
        class_weights: List of weights for each class (must match num_labels)

    Returns:
        Class-weight tensor to pass to a custom weighted-loss Trainer,
        or None if no class_weights were provided
    """
    if class_weights is not None:
        weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
        # Note: This requires custom Trainer with weighted loss
        logging.info(f"Class weights applied: {class_weights}")
        return weights_tensor
    return None


def get_model_size(model: nn.Module) -> float:
    """
    Calculate model size in millions of parameters.

    Args:
        model: PyTorch model

    Returns:
        Number of parameters in millions
    """
    param_size = sum(p.numel() for p in model.parameters())
    return param_size / 1e6


def get_trainable_params(model: nn.Module) -> Dict[str, int]:
    """
    Get count of trainable and non-trainable parameters.

    Args:
        model: PyTorch model

    Returns:
        Dictionary with 'trainable' and 'total' parameter counts
    """
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return {
        'trainable': trainable,
        'total': total,
        'non_trainable': total - trainable
    }
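apply_class_weights only returns a weight tensor; as its inline note says, the weighted loss itself needs a custom Trainer. A sketch of what that could look like (WeightedTrainer is illustrative and not part of this upload):

# Illustrative sketch of a Trainer subclass that consumes the tensor from apply_class_weights.
import torch
from transformers import Trainer


class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        weights = None
        if self.class_weights is not None:
            weights = self.class_weights.to(outputs.logits.device)
        loss = torch.nn.CrossEntropyLoss(weight=weights)(outputs.logits, labels)
        return (loss, outputs) if return_outputs else loss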
src/utils.py
ADDED
@@ -0,0 +1,255 @@
"""
Utility functions for training and evaluation
"""
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Tuple, List, Optional
import os


def compute_metrics(eval_pred, id2label: Optional[Dict[int, str]] = None) -> Dict[str, float]:
    """
    Compute comprehensive metrics for evaluation.

    Args:
        eval_pred: Tuple of (predictions, labels)
        id2label: Optional mapping from label IDs to label names for per-class metrics

    Returns:
        Dictionary of metrics including overall and per-class metrics
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    # Overall metrics
    accuracy = accuracy_score(labels, predictions)

    # Weighted metrics (accounts for class imbalance)
    precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='weighted',
        zero_division=0
    )

    # Macro-averaged metrics (treats all classes equally)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='macro',
        zero_division=0
    )

    # Micro-averaged metrics (aggregates contributions of all classes)
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels,
        predictions,
        average='micro',
        zero_division=0
    )

    metrics = {
        'accuracy': accuracy,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
        'f1_micro': f1_micro,
    }

    # Per-class metrics if label mapping is provided
    if id2label is not None:
        num_classes = len(id2label)
        precision_per_class, recall_per_class, f1_per_class, support = precision_recall_fscore_support(
            labels,
            predictions,
            labels=list(range(num_classes)),
            average=None,
            zero_division=0
        )

        for i in range(num_classes):
            label_name = id2label[i]
            metrics[f'precision_{label_name}'] = float(precision_per_class[i])
            metrics[f'recall_{label_name}'] = float(recall_per_class[i])
            metrics[f'f1_{label_name}'] = float(f1_per_class[i])
            metrics[f'support_{label_name}'] = int(support[i])

    return metrics


def compute_metrics_factory(id2label: Optional[Dict[int, str]] = None):
    """
    Factory function to create compute_metrics with label mapping.

    Args:
        id2label: Mapping from label IDs to label names

    Returns:
        Function compatible with HuggingFace Trainer
    """
    def compute_metrics_fn(eval_pred):
        return compute_metrics(eval_pred, id2label)

    return compute_metrics_fn


def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    save_path: str = "confusion_matrix.png",
    normalize: bool = False,
    figsize: Tuple[int, int] = (10, 8)
) -> None:
    """
    Plot and save confusion matrix with optional normalization.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        save_path: Path to save the plot
        normalize: If True, normalize confusion matrix to percentages
        figsize: Figure size (width, height)
    """
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
        title = 'Normalized Confusion Matrix'
    else:
        fmt = 'd'
        title = 'Confusion Matrix'

    plt.figure(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap='Blues',
        xticklabels=labels,
        yticklabels=labels,
        cbar_kws={'label': 'Percentage' if normalize else 'Count'}
    )
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Confusion matrix saved to {save_path}")


def print_classification_report(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    labels: List[str],
    output_dict: bool = False
) -> Optional[Dict]:
    """
    Print detailed classification report.

    Args:
        y_true: True labels
        y_pred: Predicted labels
        labels: List of label names
        output_dict: If True, return report as dictionary instead of printing

    Returns:
        Classification report as dictionary if output_dict=True, else None
    """
    report = classification_report(
        y_true,
        y_pred,
        target_names=labels,
        digits=4,
        output_dict=output_dict,
        zero_division=0
    )

    if output_dict:
        return report

    print("\nClassification Report:")
    print("=" * 60)
    print(report)
    return None


def plot_training_curves(
    train_losses: List[float],
    eval_losses: List[float],
    eval_metrics: Dict[str, List[float]],
    save_path: str = "./results/training_curves.png"
) -> None:
    """
    Plot training and evaluation curves.

    Args:
        train_losses: List of training losses per step/epoch
        eval_losses: List of evaluation losses per step/epoch
        eval_metrics: Dictionary of metric names to lists of values
        save_path: Path to save the plot
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curves
    axes[0, 0].plot(train_losses, label='Train Loss', color='blue')
    axes[0, 0].plot(eval_losses, label='Eval Loss', color='red')
    axes[0, 0].set_xlabel('Step/Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Accuracy
    if 'accuracy' in eval_metrics:
        axes[0, 1].plot(eval_metrics['accuracy'], label='Accuracy', color='green')
        axes[0, 1].set_xlabel('Step/Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_title('Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

    # F1 Score
    if 'f1_weighted' in eval_metrics:
        axes[1, 0].plot(eval_metrics['f1_weighted'], label='F1 (weighted)', color='purple')
        axes[1, 0].set_xlabel('Step/Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].set_title('Validation F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

    # Precision and Recall
    if 'precision_weighted' in eval_metrics and 'recall_weighted' in eval_metrics:
        axes[1, 1].plot(eval_metrics['precision_weighted'], label='Precision', color='orange')
        axes[1, 1].plot(eval_metrics['recall_weighted'], label='Recall', color='cyan')
        axes[1, 1].set_xlabel('Step/Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].set_title('Validation Precision and Recall')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    print(f"Training curves saved to {save_path}")
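compute_metrics_factory is written for the Hugging Face Trainer, which calls the returned function with (predictions, labels) at every evaluation. A hedged sketch of the wiring, pulling the pieces from the other modules and using placeholder TrainingArguments:

# Sketch only; assumes a valid config.yaml with the model/data/training/labels sections.
from transformers import Trainer, TrainingArguments
from src import load_config, prepare_datasets_for_training, create_model, compute_metrics_factory

config = load_config("config.yaml")
tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training("config.yaml")
model = create_model(config['model']['name'], len(label2id), label2id, id2label)
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="./results"),  # placeholder arguments
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics_factory(id2label),
)
print(trainer.evaluate())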
src/validation.py
ADDED
@@ -0,0 +1,174 @@
"""
Validation utilities for model and data validation
"""
import os
import yaml
from typing import Dict, List, Optional
import logging
from pathlib import Path


def validate_config(config: Dict) -> List[str]:
    """
    Validate configuration file for common issues.

    Args:
        config: Configuration dictionary

    Returns:
        List of validation error messages (empty if valid)
    """
    errors = []

    # Check required sections
    required_sections = ['model', 'training', 'data', 'labels']
    for section in required_sections:
        if section not in config:
            errors.append(f"Missing required section: {section}")

    if errors:
        return errors

    # Validate model section
    if 'name' not in config['model']:
        errors.append("model.name is required")
    if 'num_labels' not in config['model']:
        errors.append("model.num_labels is required")
    elif config['model']['num_labels'] != len(config.get('labels', [])):
        errors.append(f"model.num_labels ({config['model']['num_labels']}) doesn't match number of labels ({len(config['labels'])})")

    # Validate training section
    training = config['training']
    if 'num_train_epochs' in training and training['num_train_epochs'] <= 0:
        errors.append("training.num_train_epochs must be positive")
    if 'learning_rate' in training and training['learning_rate'] <= 0:
        errors.append("training.learning_rate must be positive")
    if 'per_device_train_batch_size' in training and training['per_device_train_batch_size'] <= 0:
        errors.append("training.per_device_train_batch_size must be positive")

    # Validate data section
    data = config['data']
    if 'data_path' in data and not os.path.exists(data['data_path']):
        errors.append(f"Data file not found: {data['data_path']}")

    train_size = data.get('train_size', 0)
    val_size = data.get('val_size', 0)
    test_size = data.get('test_size', 0)
    total = train_size + val_size + test_size
    if abs(total - 1.0) > 1e-6:
        errors.append(f"Data split sizes must sum to 1.0, got {total}")

    # Validate labels
    if 'labels' not in config or not config['labels']:
        errors.append("labels section is required and cannot be empty")
    elif len(set(config['labels'])) != len(config['labels']):
        errors.append("labels must be unique")

    return errors


def validate_model_path(model_path: str) -> bool:
    """
    Validate that model path exists and contains required files.

    Args:
        model_path: Path to model directory

    Returns:
        True if valid, False otherwise
    """
    if not os.path.exists(model_path):
        logging.error(f"Model path does not exist: {model_path}")
        return False

    required_files = ['config.json']
    for file in required_files:
        file_path = os.path.join(model_path, file)
        if not os.path.exists(file_path):
            logging.error(f"Required file missing: {file_path}")
            return False

    return True


def validate_data_file(data_path: str, required_columns: Optional[List[str]] = None) -> List[str]:
    """
    Validate data file format and content.

    Args:
        data_path: Path to data file
        required_columns: List of required column names

    Returns:
        List of validation error messages (empty if valid)
    """
    errors = []

    if required_columns is None:
        required_columns = ['comment', 'label']

    if not os.path.exists(data_path):
        errors.append(f"Data file not found: {data_path}")
        return errors

    try:
        import pandas as pd
        df = pd.read_csv(data_path)

        # Check required columns
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            errors.append(f"Missing required columns: {missing_columns}")

        # Check for empty dataframe
        if len(df) == 0:
            errors.append("Data file is empty")

        # Check for missing values in required columns
        if 'comment' in df.columns:
            empty_comments = df['comment'].isna().sum() + (df['comment'].str.strip().str.len() == 0).sum()
            if empty_comments > 0:
                errors.append(f"Found {empty_comments} empty comments")

        if 'label' in df.columns:
            missing_labels = df['label'].isna().sum()
            if missing_labels > 0:
                errors.append(f"Found {missing_labels} missing labels")

    except Exception as e:
        errors.append(f"Error reading data file: {str(e)}")

    return errors


def validate_config_file(config_path: str) -> bool:
    """
    Validate configuration file.

    Args:
        config_path: Path to configuration file

    Returns:
        True if valid, False otherwise
    """
    if not os.path.exists(config_path):
        logging.error(f"Config file not found: {config_path}")
        return False

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        errors = validate_config(config)
        if errors:
            logging.error("Configuration validation errors:")
            for error in errors:
                logging.error(f"  - {error}")
            return False

        logging.info("Configuration file is valid")
        return True

    except Exception as e:
        logging.error(f"Error reading config file: {str(e)}")
        return False
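Since the validators return plain booleans and error lists, they work as a pre-flight check before any training run. A small sketch, where the data path is a hypothetical example rather than a file from this upload:

# Pre-flight check sketch; 'data/comments.csv' is a hypothetical path.
import sys
from src.validation import validate_config_file, validate_data_file

if not validate_config_file("config.yaml"):
    sys.exit("config.yaml failed validation")

errors = validate_data_file("data/comments.csv")
if errors:
    sys.exit("\n".join(errors))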