import os
import struct
import sys

# Constants
MAX_VOCAB_SIZE = 32000
MAX_WORD_LEN = 16

def ERROR(message, *args):
    """Prints a formatted error message to stderr and exits."""
    sys.stderr.write(message % args)
    sys.exit(1)

def INFO(message, *args):
    """Prints an informational message to stdout."""
    print(message % args)

class Tokenizer:
    def __init__(self, fname=None):
        self.vocab_size = 0
        self.vocab = [''] * MAX_VOCAB_SIZE  # Preallocate vocab with empty strings

        if fname:
            self.load_tokenizer(fname)

        INFO("vocabulary size: %d (%d max)", self.vocab_size, MAX_VOCAB_SIZE)
        INFO("max token length: %d", MAX_WORD_LEN)
        # Approximate in-memory size: vocab_size fixed-length records
        structure_size = self.vocab_size * MAX_WORD_LEN
        INFO("size of structure: %d bytes", structure_size)

    def add_word(self, word):
        """Adds a word to the vocabulary.

        Words must be added in ascending sorted order for encode_word's
        binary search to find them; returns the new word's ID, or -1 if
        the vocabulary is full.
        """
        if self.vocab_size >= MAX_VOCAB_SIZE:
            return -1
        # Truncate word if it's longer than MAX_WORD_LEN - 1
        if len(word) >= MAX_WORD_LEN:
            word = word[:MAX_WORD_LEN - 1]
        self.vocab[self.vocab_size] = word
        self.vocab_size += 1
        return self.vocab_size - 1

    def encode_word(self, word):
        """Encodes a word into its ID using binary search.

        The vocabulary must be in ascending sorted order for the search
        to be correct; returns -1 if the word is not found.
        """
        left = 0
        right = self.vocab_size - 1

        while left <= right:
            mid = left + (right - left) // 2
            cmp = self._compare(word, self.vocab[mid])

            if cmp == 0:
                return mid
            elif cmp < 0:
                right = mid - 1
            else:
                left = mid + 1

        return -1
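
    # For example, with a hypothetical sorted vocabulary ["ant", "bee", "cat"],
    # encode_word("bee") returns 1 and encode_word("wasp") returns -1; with an
    # unsorted vocabulary the binary search could miss words that are present.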

    def encode_stream(self, stream):
        """
        Encodes the longest matching token from the front of a stream.

        Args:
            stream (list of str): The characters of the stream; the matched
                characters are removed from the front in place.

        Returns:
            int: The ID of the longest vocabulary entry that is a prefix of
                the stream, or -1 if no entry matches.
        """
        word = ''
        token_id = -1
        j = 0  # number of characters consumed by the longest match

        for i in range(min(MAX_WORD_LEN, len(stream))):
            word += stream[i]
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                j = i + 1

        # Remove the consumed characters from the front of the stream
        del stream[:j]

        return token_id
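
    # A minimal illustration of the greedy longest-prefix match above,
    # assuming a hypothetical sorted vocabulary ["a", "an", "and"]:
    #
    #   stream = list("andante")
    #   tok.encode_stream(stream)   # returns the ID of "and"
    #   stream                      # now equals list("ante")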

    def encode_file(self, fd):
        """
        Encodes the longest matching token from a file descriptor.

        Args:
            fd (file object): A seekable file opened in binary mode.

        Returns:
            int: The ID of the longest matching token, or -1 if no entry
                matches.
        """
        word = ''
        token_id = -1
        matched = 0     # bytes consumed by the longest match
        bytes_read = 0

        for _ in range(MAX_WORD_LEN):
            c = fd.read(1)
            if not c:
                break
            bytes_read += 1
            # Decoding byte-by-byte: bytes from multi-byte UTF-8 sequences
            # are ignored rather than assembled into characters
            word += c.decode('utf-8', errors='ignore')
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                matched = bytes_read

        # Seek back over any bytes that were read but not part of the match;
        # basing this on bytes_read rather than MAX_WORD_LEN avoids
        # over-seeking when EOF is reached early
        to_seek = bytes_read - matched
        if to_seek > 0:
            fd.seek(-to_seek, os.SEEK_CUR)

        return token_id

    def decode(self, token_id):
        """Decodes an ID back into its corresponding word."""
        if 0 <= token_id < self.vocab_size:
            return self.vocab[token_id]
        return None

    def decode_file(self, fd):
        """
        Decodes an ID read from a file descriptor back into its word.

        Args:
            fd (file object): A file opened in binary mode to decode from.

        Returns:
            str: The decoded word, or None if the ID is out of range.
        """
        data = fd.read(4)  # One 4-byte native-byte-order integer
        if len(data) < 4:
            ERROR("read EOF from file\n")

        token_id = struct.unpack('i', data)[0]
        return self.decode(token_id)
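
    # A minimal sketch of the matching writer side, under the assumption of
    # the same native byte order that struct.unpack('i', ...) above uses;
    # fd_out is a hypothetical file object opened in binary write mode:
    #
    #   fd_out.write(struct.pack('i', token_id))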

    def save_vocab(self, fname):
        """Saves the vocabulary to a text file, one word per line."""
        try:
            with open(fname, 'w', encoding='utf-8') as f:
                max_len = 0
                for i in range(self.vocab_size):
                    word = self.vocab[i]
                    f.write(word + '\n')
                    if len(word) > max_len:
                        max_len = len(word)
            INFO("wrote %d tokens to file \"%s\"\nMax token length was %d",
                 self.vocab_size, fname, max_len)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_vocab(self, fname):
        """Loads the vocabulary from a text file, one word per line.

        Lines are expected in ascending sorted order so that encode_word's
        binary search works on the loaded vocabulary.
        """
        try:
            with open(fname, 'r', encoding='utf-8') as f:
                for line in f:
                    word = line.strip()
                    if word:
                        self.add_word(word)
        except IOError as e:
            ERROR("failed to open \"%s\": %s\n", fname, str(e))

    def save_tokenizer(self, fname):
        """Saves the tokenizer's vocabulary to a binary file."""
        try:
            with open(fname, 'wb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    if i < self.vocab_size:
                        word = self.vocab[i].encode('utf-8')
                        if len(word) >= MAX_WORD_LEN:
                            # Truncate to the fixed record size; this can
                            # split a multi-byte UTF-8 character
                            word = word[:MAX_WORD_LEN - 1]
                        word += b'\0' * (MAX_WORD_LEN - len(word))
                    else:
                        word = b'\0' * MAX_WORD_LEN
                    f.write(word)
            INFO("wrote %d bytes (%d tokens) to \"%s\"",
                 MAX_VOCAB_SIZE * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_tokenizer(self, fname):
        """Loads the tokenizer's vocabulary from a binary file."""
        try:
            with open(fname, 'rb') as f:
                bytes_read = 0
                for _ in range(MAX_VOCAB_SIZE):
                    record = f.read(MAX_WORD_LEN)
                    if len(record) < MAX_WORD_LEN:
                        break
                    bytes_read += MAX_WORD_LEN
                    # Decode up to the first null byte; empty records are
                    # skipped so the loaded vocabulary stays contiguous
                    word = record.split(b'\0', 1)[0].decode('utf-8', errors='ignore')
                    if word:
                        self.vocab[self.vocab_size] = word
                        self.vocab_size += 1
            INFO("read %d bytes (%d tokens) from \"%s\"",
                 bytes_read, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to read from \"%s\": %s\n", fname, str(e))

    @staticmethod
    def _compare(a, b):
        """Helper method to compare two strings similar to strcmp in C."""
        return (a > b) - (a < b)
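

# A minimal usage sketch (an illustration, not part of the original module):
# build a small sorted vocabulary, round-trip it through the binary tokenizer
# format, and greedily encode a character stream. The path "tokenizer.bin" is
# an arbitrary example file name.
if __name__ == "__main__":
    tok = Tokenizer()
    # Words are added in sorted order so encode_word's binary search works
    for w in sorted(["a", "an", "ant", "e", "en", "n", "na", "t", "te", "ten"]):
        tok.add_word(w)

    tok.save_tokenizer("tokenizer.bin")
    tok2 = Tokenizer("tokenizer.bin")

    stream = list("antenna")
    while stream:
        token_id = tok2.encode_stream(stream)
        if token_id == -1:
            break  # no vocabulary entry matches the next character
        print(token_id, tok2.decode(token_id))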