""" |
|
|
Data loader utilities for Code Comment Quality Classifier |
|
|
""" |
|
|
import pandas as pd |
|
|
from datasets import Dataset, DatasetDict |
|
|
from sklearn.model_selection import train_test_split |
|
|
from typing import Tuple, Dict, List, Optional |
|
|
import yaml |
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |


def load_config(config_path: str = "config.yaml") -> dict:
    """Load configuration from a YAML file."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config
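
# Illustrative config.yaml layout, inferred from the keys this module reads
# (data.*, labels, training.seed, model.*). Values below are examples only,
# not project defaults:
#
#   data:
#     data_path: data/comments.csv
#     train_size: 0.8
#     val_size: 0.1
#     test_size: 0.1
#     stratify: true
#   labels: [excellent, helpful, unclear, outdated]
#   training:
#     seed: 42
#   model:
#     name: distilbert-base-uncased
#     max_length: 512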


def load_data(data_path: str) -> pd.DataFrame:
    """
    Load data from a CSV file with validation.

    Expected format:
    - comment: str (the code comment text)
    - label: str (excellent, helpful, unclear, or outdated)

    Args:
        data_path: Path to the CSV file.

    Returns:
        DataFrame with validated data.

    Raises:
        FileNotFoundError: If the data file doesn't exist.
        ValueError: If the data format is invalid.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    df = pd.read_csv(data_path)

    # Fail fast if a required column is absent.
    required_columns = ['comment', 'label']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Drop rows with a missing comment or label. Because 'label' is part of
    # required_columns, no separate NaN-label check is needed afterwards.
    initial_len = len(df)
    df = df.dropna(subset=required_columns)
    if len(df) < initial_len:
        logging.warning(f"Removed {initial_len - len(df)} rows with missing values")

    # Drop rows whose comment is empty or whitespace-only.
    before_filter = len(df)
    df = df[df['comment'].str.strip().str.len() > 0]
    if len(df) < before_filter:
        logging.warning(f"Removed {before_filter - len(df)} rows with empty comments")

    logging.info(f"Loaded {len(df)} samples from {data_path}")
    return df
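
# Example of the expected CSV layout (rows are illustrative):
#
#   comment,label
#   "Validates input before saving to avoid partial writes.",excellent
#   "Retry up to 3 times on timeout.",helpful
#   "do stuff",unclear
#   "Calls the v1 endpoint, which no longer exists.",outdated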


def create_label_mapping(labels: list) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Create bidirectional label mappings (name-to-id and id-to-name)."""
    label2id = {label: idx for idx, label in enumerate(labels)}
    id2label = {idx: label for idx, label in enumerate(labels)}
    return label2id, id2label
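
# Example: assuming the label order ['excellent', 'helpful', 'unclear',
# 'outdated'] from the config, this produces:
#   label2id == {'excellent': 0, 'helpful': 1, 'unclear': 2, 'outdated': 3}
#   id2label == {0: 'excellent', 1: 'helpful', 2: 'unclear', 3: 'outdated'}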


def prepare_dataset(
    df: pd.DataFrame,
    label2id: Dict[str, int],
    train_size: float = 0.8,
    val_size: float = 0.1,
    test_size: float = 0.1,
    seed: int = 42,
    stratify: bool = True
) -> DatasetDict:
    """
    Prepare dataset splits for training.

    Args:
        df: DataFrame with 'comment' and 'label' columns.
        label2id: Mapping from label names to IDs.
        train_size: Proportion of training data.
        val_size: Proportion of validation data.
        test_size: Proportion of test data.
        seed: Random seed for reproducibility.
        stratify: Whether to maintain the class distribution in each split.

    Returns:
        DatasetDict with train, validation, and test splits.
    """
    # Validate labels up front. Every label in df must appear in label2id,
    # which also guarantees the .map() below produces no NaN IDs.
    invalid_labels = set(df['label'].unique()) - set(label2id.keys())
    if invalid_labels:
        raise ValueError(f"Found invalid labels: {invalid_labels}. Expected: {list(label2id.keys())}")

    # Validate the split proportions before doing any work.
    total_size = train_size + val_size + test_size
    if abs(total_size - 1.0) > 1e-6:
        raise ValueError(f"Split sizes must sum to 1.0, got {total_size}")

    # Work on a copy so the caller's DataFrame isn't mutated.
    df = df.copy()
    df['label_id'] = df['label'].map(label2id)

    # First split: carve off the test set.
    stratify_col = df['label_id'] if stratify else None
    train_val_df, test_df = train_test_split(
        df,
        test_size=test_size,
        random_state=seed,
        stratify=stratify_col
    )

    # Second split: separate validation from train. val_size is rescaled
    # relative to the remaining train+val pool, e.g. 0.1 / (0.8 + 0.1) ≈ 0.111
    # for an 80/10/10 split.
    val_size_adjusted = val_size / (train_size + val_size)
    stratify_col_train = train_val_df['label_id'] if stratify else None
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_size_adjusted,
        random_state=seed,
        stratify=stratify_col_train
    )

    logging.info(f"Dataset splits - Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    logging.info(f"Train label distribution:\n{train_df['label'].value_counts().sort_index()}")

    return DatasetDict({
        'train': Dataset.from_pandas(train_df[['comment', 'label_id']], preserve_index=False),
        'validation': Dataset.from_pandas(val_df[['comment', 'label_id']], preserve_index=False),
        'test': Dataset.from_pandas(test_df[['comment', 'label_id']], preserve_index=False)
    })
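
# Minimal usage sketch with toy data (labels assume the four classes listed
# in load_data's docstring):
#
#   df = pd.DataFrame({
#       'comment': ['Good comment'] * 50 + ['???'] * 50,
#       'label': ['excellent'] * 50 + ['unclear'] * 50,
#   })
#   label2id, _ = create_label_mapping(['excellent', 'helpful', 'unclear', 'outdated'])
#   splits = prepare_dataset(df, label2id)  # defaults: 80/10/10, stratified
#   # len(splits['train']), len(splits['validation']), len(splits['test'])
#   # -> roughly 80, 10, 10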


def tokenize_function(examples, tokenizer, max_length: int = 512):
    """Tokenize a batch of comments, padding and truncating to max_length."""
    return tokenizer(
        examples['comment'],
        padding='max_length',
        truncation=True,
        max_length=max_length
    )
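
# Note on the padding strategy: padding every example to max_length keeps
# tensor shapes fixed but wastes compute on short comments. A common
# alternative (not used here) is to tokenize with padding=False and pad each
# batch dynamically at train time with transformers.DataCollatorWithPadding.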


def prepare_datasets_for_training(config_path: str = "config.yaml"):
    """
    Complete pipeline to prepare datasets for training.

    Returns:
        Tuple of (tokenized_datasets, label2id, id2label, tokenizer)
    """
    # Local import keeps transformers optional for the rest of the module.
    from transformers import AutoTokenizer

    config = load_config(config_path)

    # Load and validate the raw CSV.
    df = load_data(config['data']['data_path'])

    # Build label/id mappings from the configured label list.
    labels = config['labels']
    label2id, id2label = create_label_mapping(labels)

    # Split into train/validation/test.
    stratify = config['data'].get('stratify', True)
    dataset_dict = prepare_dataset(
        df,
        label2id,
        train_size=config['data']['train_size'],
        val_size=config['data']['val_size'],
        test_size=config['data']['test_size'],
        seed=config['training']['seed'],
        stratify=stratify
    )

    # Tokenize all splits; drop the raw text column once it is encoded.
    tokenizer = AutoTokenizer.from_pretrained(config['model']['name'])
    tokenized_datasets = dataset_dict.map(
        lambda x: tokenize_function(x, tokenizer, config['model']['max_length']),
        batched=True,
        remove_columns=['comment']
    )

    # The Trainer API expects the target column to be named 'labels'.
    tokenized_datasets = tokenized_datasets.rename_column('label_id', 'labels')

    return tokenized_datasets, label2id, id2label, tokenizer
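
# Example entry point. A minimal sketch: assumes a config.yaml matching the
# layout outlined near the top of this file sits in the working directory.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    datasets, label2id, id2label, tokenizer = prepare_datasets_for_training()
    print(datasets)
    print(f"label2id: {label2id}")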