Spaces:
Runtime error
Runtime error
File size: 4,021 Bytes
ec8f374 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
"""
Dataset Builder Module
Handles train/test splitting and dataset creation.
"""
import random
from typing import List, Dict, Any, Tuple, Optional
class DatasetBuilder:
    """Build and split datasets for training.

    All shuffling is done with a private ``random.Random`` instance seeded
    in ``__init__``, so the splits produced by one builder depend only on
    its seed and on the calls made to that builder, not on other code that
    uses the module-level ``random`` functions.
    """

    def __init__(self, seed: int = 42):
        """
        Initialize dataset builder.

        Args:
            seed: Random seed for reproducibility
        """
        self.seed = seed
        # Retained for backward compatibility: callers may rely on the
        # constructor seeding the module-level generator.
        random.seed(seed)
        # Dedicated generator so that splits are not perturbed by unrelated
        # use of the global ``random`` state (or by other builders).
        self._rng = random.Random(seed)

    def train_test_split(
        self,
        data: List[Dict[str, Any]],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
        shuffle: bool = True
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """
        Split data into train/validation/test sets.

        The input list is never modified; a shallow copy is shuffled and
        sliced. Split sizes are truncated towards zero, so any remainder
        left by rounding ends up in the test set.

        Args:
            data: List of data examples
            train_ratio: Fraction for training
            val_ratio: Fraction for validation
            test_ratio: Fraction for testing
            shuffle: Whether to shuffle data

        Returns:
            Tuple of (train_data, val_data, test_data)

        Raises:
            ValueError: If the ratios do not sum to 1.0 (within 0.01) or if
                any ratio is negative.
        """
        # Validate ratios
        total = train_ratio + val_ratio + test_ratio
        if abs(total - 1.0) > 0.01:
            raise ValueError(f"Ratios must sum to 1.0, got {total}")
        # A negative ratio can still satisfy the sum check but would yield
        # overlapping or reversed slice boundaries below.
        if min(train_ratio, val_ratio, test_ratio) < 0:
            raise ValueError(
                "Ratios must be non-negative, got "
                f"train={train_ratio}, val={val_ratio}, test={test_ratio}"
            )

        # Shuffle if requested (on a copy, leaving the caller's list intact)
        data_copy = data.copy()
        if shuffle:
            self._rng.shuffle(data_copy)

        # Calculate split indices
        n = len(data_copy)
        train_end = int(n * train_ratio)
        val_end = train_end + int(n * val_ratio)

        # Split
        train_data = data_copy[:train_end]
        val_data = data_copy[train_end:val_end]
        test_data = data_copy[val_end:]
        return train_data, val_data, test_data

    def create_balanced_split(
        self,
        data: List[Dict[str, Any]],
        category_key: str,
        train_ratio: float = 0.8
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Create balanced train/test split by category.

        Examples are grouped by the value stored under ``category_key``
        (examples lacking the key are grouped under "unknown"); each group
        is shuffled and split independently with the same ratio, and the
        combined train and test lists are shuffled once more.

        Args:
            data: List of data examples
            category_key: Key for category field
            train_ratio: Fraction for training

        Returns:
            Tuple of (train_data, test_data)

        Raises:
            ValueError: If ``train_ratio`` is not within [0, 1].
        """
        # A negative ratio would produce a negative slice index and silently
        # put the wrong examples in each split.
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(
                f"train_ratio must be between 0 and 1, got {train_ratio}"
            )

        # Group by category (new lists, so the caller's data is not mutated)
        categories: Dict[Any, List[Dict[str, Any]]] = {}
        for example in data:
            cat = example.get(category_key, "unknown")
            categories.setdefault(cat, []).append(example)

        # Split each category
        train_data: List[Dict[str, Any]] = []
        test_data: List[Dict[str, Any]] = []
        for examples in categories.values():
            self._rng.shuffle(examples)
            split_idx = int(len(examples) * train_ratio)
            train_data.extend(examples[:split_idx])
            test_data.extend(examples[split_idx:])

        # Shuffle final datasets so categories are interleaved
        self._rng.shuffle(train_data)
        self._rng.shuffle(test_data)
        return train_data, test_data

    def save_split(
        self,
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        output_dir: str
    ) -> None:
        """
        Save split datasets to files.

        Writes ``train.json``, ``val.json`` and ``test.json`` into
        ``output_dir`` (created, including parents, if it does not exist)
        and prints a short summary of the split sizes.

        Args:
            train_data: Training data
            val_data: Validation data
            test_data: Test data
            output_dir: Output directory
        """
        import json
        from pathlib import Path

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save each split; the encoding is explicit so the files do not
        # depend on the platform's default locale.
        splits = (
            ("train.json", train_data),
            ("val.json", val_data),
            ("test.json", test_data),
        )
        for filename, split in splits:
            with open(output_path / filename, 'w', encoding="utf-8") as f:
                json.dump(split, f, indent=2)

        print(f"✅ Datasets saved to {output_dir}")
        print(f" Train: {len(train_data)} examples")
        print(f" Val: {len(val_data)} examples")
        print(f" Test: {len(test_data)} examples")
|