File size: 3,200 Bytes
13c35e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import random

def generate_v1_data(*, seed=None, shuffle=True):
    """Generate all single-digit arithmetic problems, shuffled, with <EOS>.

    Every pair of operands ``a, b`` in 0-9 is combined with ``+``, ``-``,
    ``*`` and ``/``; a problem is kept only when its answer is a whole number
    in 0-9. For ``/`` that additionally means ``b != 0`` and ``a % b == 0``.
    Each kept problem is rendered as ``"a op b = answer"`` and the marker
    ``"<EOS>"`` is appended to it.

    Args:
        seed: Optional seed for a private ``random.Random`` instance, making
            the shuffle reproducible without touching global random state.
            ``None`` (default) shuffles with the module-level ``random``
            generator.
        shuffle: If False, skip shuffling and return problems in generation
            order (by ``a``, then ``b``, then operator ``+ - * /``).

    Returns:
        A list of 184 strings, e.g. ``"3 + 4 = 7<EOS>"``.
    """
    problems = []
    for a in range(10):
        for b in range(10):
            # Division is only defined when b divides a exactly; integer
            # arithmetic avoids a float round-trip through int().
            exact_div = a // b if b != 0 and a % b == 0 else None
            # Fixed operator order keeps generation order stable, so the
            # default (global-RNG) shuffle behaves as it always has.
            candidates = (
                ("+", a + b),
                ("-", a - b),
                ("*", a * b),
                ("/", exact_div),
            )
            for symbol, answer in candidates:
                # Keep only single-digit, non-negative answers.
                if answer is not None and 0 <= answer <= 9:
                    problems.append(f"{a} {symbol} {b} = {answer}")

    if shuffle:
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(problems)
    return [problem + "<EOS>" for problem in problems]

class CharacterTokenizer:
    """A simple character-level tokenizer for the math problems.

    The vocabulary is every distinct character appearing in ``data`` (sorted),
    followed by one extra ``'<PAD>'`` token used for padding batches.
    Multi-character markers inside the data, such as ``"<EOS>"``, are NOT
    single tokens: they are split into their individual characters.

    Attributes:
        stoi: Mapping from token (a character, or ``'<PAD>'``) to integer id.
        itos: Inverse mapping from integer id to token.
        vocab_size: Number of tokens, including ``'<PAD>'``.
        pad_token_id: Id of the ``'<PAD>'`` token (always the last id).
    """

    PAD_TOKEN = "<PAD>"

    def __init__(self, data):
        # Distinct single characters across the whole dataset, sorted so that
        # id assignment is deterministic for a given dataset. The data must be
        # generated before the tokenizer is built.
        chars = sorted(set("".join(data)))
        # Every entry of `chars` is one character, so the 5-character '<PAD>'
        # can never collide with one of them; it is simply appended last.
        tokens = chars + [self.PAD_TOKEN]

        self.stoi = {tok: i for i, tok in enumerate(tokens)}
        self.itos = {i: tok for i, tok in enumerate(tokens)}
        self.vocab_size = len(tokens)
        self.pad_token_id = self.stoi[self.PAD_TOKEN]

    def encode(self, s):
        """Encode a string into a list of integer ids, one per character.

        Raises:
            KeyError: If ``s`` contains a character not seen in the data.
                The ``'<PAD>'`` token cannot be produced from text (it would
                be split into characters); use ``pad_token_id`` directly.
        """
        return [self.stoi[c] for c in s]

    def decode(self, l, *, skip_pad=False):
        """Decode a sequence of integer ids back into a string.

        Args:
            l: Iterable of ids. Each element is converted with ``int()``, so
                integer-like scalars (e.g. elements of a NumPy array or a
                PyTorch tensor) are accepted as well as plain ints.
            skip_pad: If True, drop ``pad_token_id`` entries before joining.

        Raises:
            KeyError: If an id is outside the vocabulary.
        """
        ids = (int(i) for i in l)
        if skip_pad:
            ids = (i for i in ids if i != self.pad_token_id)
        return "".join(self.itos[i] for i in ids)