File size: 4,021 Bytes
ec8f374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Dataset Builder Module

Handles train/test splitting and dataset creation.
"""

import random
from typing import List, Dict, Any, Tuple, Optional


class DatasetBuilder:
    """Build and split datasets for training.

    All randomness is driven by a private ``random.Random`` instance seeded
    in ``__init__``, so splits are reproducible for a given seed without
    mutating the global ``random`` module state (the previous implementation
    called ``random.seed()``, which affected every other user of ``random``
    in the process).
    """

    def __init__(self, seed: int = 42):
        """
        Initialize dataset builder.

        Args:
            seed: Random seed for reproducibility
        """
        self.seed = seed
        # Private RNG: deterministic per-instance shuffling without
        # clobbering the module-level random state.
        self._rng = random.Random(seed)

    def train_test_split(
        self,
        data: List[Dict[str, Any]],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
        shuffle: bool = True
    ) -> Tuple[List[Dict], List[Dict], List[Dict]]:
        """
        Split data into train/validation/test sets.

        Args:
            data: List of data examples
            train_ratio: Fraction for training
            val_ratio: Fraction for validation
            test_ratio: Fraction for testing
            shuffle: Whether to shuffle data

        Returns:
            Tuple of (train_data, val_data, test_data)

        Raises:
            ValueError: If the three ratios do not sum to 1.0
                (within a 0.01 tolerance for float rounding).
        """
        # Validate ratios (small tolerance absorbs float rounding,
        # e.g. 0.7 + 0.2 + 0.1).
        total = train_ratio + val_ratio + test_ratio
        if abs(total - 1.0) > 0.01:
            raise ValueError(f"Ratios must sum to 1.0, got {total}")

        # Work on a shallow copy so the caller's list order is untouched.
        data_copy = data.copy()
        if shuffle:
            self._rng.shuffle(data_copy)

        # Calculate split boundaries. Truncation may leave the test split
        # slightly larger; it absorbs the remainder so no example is lost.
        n = len(data_copy)
        train_end = int(n * train_ratio)
        val_end = train_end + int(n * val_ratio)

        train_data = data_copy[:train_end]
        val_data = data_copy[train_end:val_end]
        test_data = data_copy[val_end:]

        return train_data, val_data, test_data

    def create_balanced_split(
        self,
        data: List[Dict[str, Any]],
        category_key: str,
        train_ratio: float = 0.8
    ) -> Tuple[List[Dict], List[Dict]]:
        """
        Create balanced train/test split by category.

        Each category is shuffled and split independently with the same
        ratio, so the train/test proportions hold per category, not just
        overall. Examples missing ``category_key`` fall into an "unknown"
        bucket.

        Args:
            data: List of data examples
            category_key: Key for category field
            train_ratio: Fraction for training

        Returns:
            Tuple of (train_data, test_data)
        """
        # Group examples by category value.
        categories: Dict[Any, List[Dict[str, Any]]] = {}
        for example in data:
            cat = example.get(category_key, "unknown")
            categories.setdefault(cat, []).append(example)

        # Split each category independently at the same ratio.
        train_data: List[Dict[str, Any]] = []
        test_data: List[Dict[str, Any]] = []

        for examples in categories.values():
            # `examples` are freshly-built lists, so shuffling in place
            # does not disturb the caller's data.
            self._rng.shuffle(examples)
            split_idx = int(len(examples) * train_ratio)
            train_data.extend(examples[:split_idx])
            test_data.extend(examples[split_idx:])

        # Shuffle the merged sets so categories are interleaved rather
        # than grouped in insertion order.
        self._rng.shuffle(train_data)
        self._rng.shuffle(test_data)

        return train_data, test_data

    def save_split(
        self,
        train_data: List[Dict],
        val_data: List[Dict],
        test_data: List[Dict],
        output_dir: str
    ) -> None:
        """
        Save split datasets to train.json / val.json / test.json.

        Args:
            train_data: Training data
            val_data: Validation data
            test_data: Test data
            output_dir: Output directory (created if missing)
        """
        import json
        from pathlib import Path

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Explicit UTF-8 so output does not depend on the platform's
        # default encoding (e.g. cp1252 on Windows).
        splits = {
            "train.json": train_data,
            "val.json": val_data,
            "test.json": test_data,
        }
        for filename, split in splits.items():
            with open(output_path / filename, 'w', encoding='utf-8') as f:
                json.dump(split, f, indent=2)

        print(f"✅ Datasets saved to {output_dir}")
        print(f"  Train: {len(train_data)} examples")
        print(f"  Val: {len(val_data)} examples")
        print(f"  Test: {len(test_data)} examples")