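"""A greedy longest-prefix tokenizer with a fixed-size vocabulary.

Words are kept in a sorted list and looked up by binary search; encoding
consumes the longest vocabulary word that prefixes the input. The vocabulary
can be saved to and loaded from a text file (one word per line) or a binary
file of MAX_VOCAB_SIZE fixed-size, null-padded records.
"""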
import os
import struct
import sys

# Constants
MAX_VOCAB_SIZE = 32000
MAX_WORD_LEN = 16


def ERROR(message, *args):
    """Prints an error message to stderr and exits."""
    sys.stderr.write(message % args)
    sys.exit(1)


def INFO(message, *args):
    """Prints an informational message to stdout."""
    print(message % args)

class Tokenizer:
    def __init__(self, fname=None):
        self.vocab_size = 0
        self.vocab = [''] * MAX_VOCAB_SIZE  # Preallocate vocab with empty strings
        if fname:
            self.load_tokenizer(fname)
        INFO("vocabulary size: %d (%d max)", self.vocab_size, MAX_VOCAB_SIZE)
        INFO("max token length: %d", MAX_WORD_LEN)
        # Approximate size of structure: vocab_size * MAX_WORD_LEN + overhead
        structure_size = self.vocab_size * MAX_WORD_LEN
        INFO("size of structure: %d bytes", structure_size)
    def add_word(self, word):
        """Adds a word to the vocabulary.

        Words must be added in sorted order, since encode_word looks them
        up by binary search.
        """
        if self.vocab_size >= MAX_VOCAB_SIZE:
            return -1
        # Truncate word if it's longer than MAX_WORD_LEN - 1 (the binary
        # format reserves at least one trailing null byte per record)
        if len(word) >= MAX_WORD_LEN:
            word = word[:MAX_WORD_LEN - 1]
        self.vocab[self.vocab_size] = word
        self.vocab_size += 1
        return self.vocab_size - 1
    def encode_word(self, word):
        """Encodes a word into its corresponding ID using binary search.

        Assumes self.vocab[:self.vocab_size] is sorted lexicographically.
        Returns -1 if the word is not in the vocabulary.
        """
        left = 0
        right = self.vocab_size - 1
        while left <= right:
            mid = left + (right - left) // 2
            cmp = self._compare(word, self.vocab[mid])
            if cmp == 0:
                return mid
            elif cmp < 0:
                right = mid - 1
            else:
                left = mid + 1
        return -1
    def encode_stream(self, stream):
        """
        Encodes the longest vocabulary prefix of a character stream.

        Args:
            stream (list of str): A list containing the characters of the
                stream. Matched characters are removed from the list in place.

        Returns:
            int: The ID of the longest matching prefix, or -1 if none matched.
        """
        word = ''
        token_id = -1
        j = 0
        for i in range(min(MAX_WORD_LEN, len(stream))):
            word += stream[i]
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                j = i + 1
        # Modify the stream in place to remove the matched characters
        del stream[:j]
        return token_id
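    # Illustrative example with a hypothetical sorted vocab ["a", "ab", "abc"]:
    # encode_stream(list("abcd")) returns the ID of "abc" (the longest
    # matching prefix) and leaves ["d"] in the stream.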
    def encode_file(self, fd):
        """
        Encodes the longest vocabulary prefix read from a file.

        Args:
            fd (file object): The file to encode from, opened in binary mode.
                The file position is left just past the matched bytes.

        Returns:
            int: The ID of the longest matching prefix, or -1 if none matched.
        """
        word = ''
        token_id = -1
        j = 0           # bytes consumed by the longest match so far
        bytes_read = 0  # total bytes read in this call
        for _ in range(MAX_WORD_LEN):
            c = fd.read(1)
            if not c:
                break
            bytes_read += 1
            word += c.decode('utf-8', errors='ignore')
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                j = bytes_read
        # Seek back over any bytes read past the longest match
        if bytes_read > j:
            fd.seek(j - bytes_read, os.SEEK_CUR)
        return token_id
    def decode(self, id):
        """Decodes an ID back into its corresponding word."""
        if 0 <= id < self.vocab_size:
            return self.vocab[id]
        return None
    def decode_file(self, fd):
        """
        Decodes an ID read from a file descriptor back into its corresponding word.

        Args:
            fd (file object): The file to decode from, opened in binary mode.

        Returns:
            str: The decoded word, or None if the ID is out of range.
        """
        data = fd.read(4)  # Read 4 bytes for an integer
        if len(data) < 4:
            ERROR("read EOF from file\n")
        token_id = struct.unpack('i', data)[0]  # 'i' is a native-endian int
        return self.decode(token_id)
    def save_vocab(self, fname):
        """Saves the vocabulary to a text file, one word per line."""
        try:
            with open(fname, 'w', encoding='utf-8') as f:
                max_len = 0
                for i in range(self.vocab_size):
                    word = self.vocab[i]
                    f.write(word + '\n')
                    if len(word) > max_len:
                        max_len = len(word)
            INFO("wrote %d tokens to file \"%s\"\nMax token length was %d",
                 self.vocab_size, fname, max_len)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))
    def load_vocab(self, fname):
        """Loads the vocabulary from a text file, expecting one word per line."""
        try:
            with open(fname, 'r', encoding='utf-8') as f:
                for line in f:
                    word = line.strip()
                    if word:
                        self.add_word(word)
        except IOError as e:
            ERROR("failed to open \"%s\": %s\n", fname, str(e))
    def save_tokenizer(self, fname):
        """Saves the vocabulary to a binary file of fixed-size records.

        The file holds MAX_VOCAB_SIZE records of MAX_WORD_LEN bytes each;
        every word is null-padded, and unused records are all null bytes.
        """
        try:
            with open(fname, 'wb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    if i < self.vocab_size:
                        word = self.vocab[i].encode('utf-8')
                        # Keep at least one trailing null byte per record
                        if len(word) >= MAX_WORD_LEN:
                            word = word[:MAX_WORD_LEN - 1]
                        word += b'\0' * (MAX_WORD_LEN - len(word))
                    else:
                        word = b'\0' * MAX_WORD_LEN
                    f.write(word)
            INFO("wrote %d bytes (%d tokens) to \"%s\"",
                 MAX_VOCAB_SIZE * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))
    def load_tokenizer(self, fname):
        """Loads the vocabulary from a binary file of fixed-size records."""
        try:
            with open(fname, 'rb') as f:
                bytes_read = 0
                for _ in range(MAX_VOCAB_SIZE):
                    record = f.read(MAX_WORD_LEN)
                    if len(record) < MAX_WORD_LEN:
                        break
                    bytes_read += MAX_WORD_LEN
                    # Decode up to the first null byte
                    word = record.split(b'\0', 1)[0].decode('utf-8', errors='ignore')
                    if word:
                        self.vocab[self.vocab_size] = word
                        self.vocab_size += 1
            INFO("read %d bytes (%d tokens) from \"%s\"",
                 bytes_read, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to read from \"%s\": %s\n", fname, str(e))
    @staticmethod
    def _compare(a, b):
        """Helper method to compare two strings, similar to strcmp in C."""
        return (a > b) - (a < b)
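
# A minimal usage sketch: builds a small sorted vocabulary, encodes a
# character stream, and round-trips the vocabulary through the binary
# format. The file name "tokenizer.bin" is illustrative.
if __name__ == "__main__":
    tok = Tokenizer()
    for w in sorted(["a", "ab", "abc", "b", "hello"]):
        tok.add_word(w)  # add_word expects sorted order for binary search

    stream = list("abchello!")
    first = tok.encode_stream(stream)   # consumes "abc"
    second = tok.encode_stream(stream)  # consumes "hello"
    print(tok.decode(first), tok.decode(second), "rest:", ''.join(stream))

    tok.save_tokenizer("tokenizer.bin")
    reloaded = Tokenizer("tokenizer.bin")
    assert reloaded.encode_word("hello") == tok.encode_word("hello")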