"""Utilities for validating configuration files, model paths, and data files."""

import logging
import os
from typing import Dict, List, Optional

import yaml


def validate_config(config: Dict) -> List[str]:
    """
    Validate a configuration dictionary for common issues.

    Args:
        config: Configuration dictionary

    Returns:
        List of validation error messages (empty if valid)
    """
    errors = []

    # Check that all required top-level sections are present.
    required_sections = ['model', 'training', 'data', 'labels']
    for section in required_sections:
        if section not in config:
            errors.append(f"Missing required section: {section}")

    if errors:
        return errors

    # Model settings.
    if 'name' not in config['model']:
        errors.append("model.name is required")
    if 'num_labels' not in config['model']:
        errors.append("model.num_labels is required")
    elif config['model']['num_labels'] != len(config.get('labels', [])):
        errors.append(
            f"model.num_labels ({config['model']['num_labels']}) doesn't match "
            f"number of labels ({len(config['labels'])})"
        )

    # Training hyperparameters.
    training = config['training']
    if 'num_train_epochs' in training and training['num_train_epochs'] <= 0:
        errors.append("training.num_train_epochs must be positive")
    if 'learning_rate' in training and training['learning_rate'] <= 0:
        errors.append("training.learning_rate must be positive")
    if 'per_device_train_batch_size' in training and training['per_device_train_batch_size'] <= 0:
        errors.append("training.per_device_train_batch_size must be positive")

    # Data settings and split ratios.
    data = config['data']
    if 'data_path' in data and not os.path.exists(data['data_path']):
        errors.append(f"Data file not found: {data['data_path']}")

    train_size = data.get('train_size', 0)
    val_size = data.get('val_size', 0)
    test_size = data.get('test_size', 0)
    total = train_size + val_size + test_size
    if abs(total - 1.0) > 1e-6:
        errors.append(f"Data split sizes must sum to 1.0, got {total}")

    # Label list.
    if 'labels' not in config or not config['labels']:
        errors.append("labels section is required and cannot be empty")
    elif len(set(config['labels'])) != len(config['labels']):
        errors.append("labels must be unique")

    return errors


def validate_model_path(model_path: str) -> bool:
    """
    Validate that model path exists and contains required files.

    Args:
        model_path: Path to model directory

    Returns:
        True if valid, False otherwise
    """
    if not os.path.exists(model_path):
        logging.error(f"Model path does not exist: {model_path}")
        return False

    required_files = ['config.json']
    for file in required_files:
        file_path = os.path.join(model_path, file)
        if not os.path.exists(file_path):
            logging.error(f"Required file missing: {file_path}")
            return False

    return True


def validate_data_file(data_path: str, required_columns: Optional[List[str]] = None) -> List[str]:
    """
    Validate data file format and content.

    Args:
        data_path: Path to data file
        required_columns: List of required column names

    Returns:
        List of validation error messages (empty if valid)
    """
    errors = []

    if required_columns is None:
        required_columns = ['comment', 'label']

    if not os.path.exists(data_path):
        errors.append(f"Data file not found: {data_path}")
        return errors

    try:
        # Imported lazily so pandas is only required when validating data files.
        import pandas as pd

        df = pd.read_csv(data_path)

        # Required columns.
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            errors.append(f"Missing required columns: {missing_columns}")

        # Row count.
        if len(df) == 0:
            errors.append("Data file is empty")

        # Empty or whitespace-only comments; count NaN values separately so the
        # .str accessor never sees missing or non-string data.
        if 'comment' in df.columns:
            comments = df['comment']
            empty_comments = (
                comments.isna().sum()
                + (comments.dropna().astype(str).str.strip() == '').sum()
            )
            if empty_comments > 0:
                errors.append(f"Found {empty_comments} empty comments")

        # Missing labels.
        if 'label' in df.columns:
            missing_labels = df['label'].isna().sum()
            if missing_labels > 0:
                errors.append(f"Found {missing_labels} missing labels")

    except Exception as e:
        errors.append(f"Error reading data file: {str(e)}")

    return errors


def validate_config_file(config_path: str) -> bool:
    """
    Validate a YAML configuration file.

    Args:
        config_path: Path to configuration file

    Returns:
        True if valid, False otherwise
    """
    if not os.path.exists(config_path):
        logging.error(f"Config file not found: {config_path}")
        return False

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # An empty or scalar YAML file loads to something other than a dict;
        # report that directly instead of failing inside validate_config.
        if not isinstance(config, dict):
            logging.error(f"Config file must contain a YAML mapping: {config_path}")
            return False

        errors = validate_config(config)
        if errors:
            logging.error("Configuration validation errors:")
            for error in errors:
                logging.error(f"  - {error}")
            return False

        logging.info("Configuration file is valid")
        return True

    except Exception as e:
        logging.error(f"Error reading config file: {str(e)}")
        return False
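

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # the config dict values and file paths below are placeholder assumptions.
    logging.basicConfig(level=logging.INFO)

    example_config = {
        'model': {'name': 'bert-base-uncased', 'num_labels': 2},
        'training': {
            'num_train_epochs': 3,
            'learning_rate': 2e-5,
            'per_device_train_batch_size': 16,
        },
        'data': {'train_size': 0.8, 'val_size': 0.1, 'test_size': 0.1},
        'labels': ['negative', 'positive'],
    }
    for problem in validate_config(example_config):
        print(f"config error: {problem}")

    # File-based checks; swap in real paths before relying on the results.
    validate_config_file("config.yaml")
    for problem in validate_data_file("data/comments.csv"):
        print(f"data error: {problem}")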