"""Fixed-size vocabulary tokenizer.

The vocabulary is a flat array of MAX_VOCAB_SIZE fixed-width, NUL-padded
slots of MAX_WORD_LEN bytes each (a C-style layout), and encoding uses
greedy longest-prefix matching.
"""

import os
import struct
import sys

MAX_VOCAB_SIZE = 32000
MAX_WORD_LEN = 16


def ERROR(message, *args):
    """Prints an error message to stderr and exits."""
    sys.stderr.write(message % args)
    sys.exit(1)


def INFO(message, *args):
    """Prints an informational message to stdout."""
    print(message % args)


class Tokenizer:
    def __init__(self, fname=None):
        self.vocab_size = 0
        self.vocab = [''] * MAX_VOCAB_SIZE

        if fname:
            self.load_tokenizer(fname)

        INFO("vocabulary size: %d (%d max)", self.vocab_size, MAX_VOCAB_SIZE)
        INFO("max token length: %d", MAX_WORD_LEN)

        # Size the vocabulary would occupy as a fixed-width array.
        structure_size = self.vocab_size * MAX_WORD_LEN
        INFO("size of structure: %d bytes", structure_size)

    def add_word(self, word):
        """Adds a word to the vocabulary. Returns its ID, or -1 if full."""
        if self.vocab_size >= MAX_VOCAB_SIZE:
            return -1

        # Truncate to fit a fixed-width slot, reserving one byte for the
        # terminating NUL used by the binary format.
        if len(word) >= MAX_WORD_LEN:
            word = word[:MAX_WORD_LEN - 1]
        self.vocab[self.vocab_size] = word
        self.vocab_size += 1
        return self.vocab_size - 1

    def encode_word(self, word):
        """Encodes a word into its corresponding ID using binary search.

        Assumes the vocabulary is sorted; returns -1 if the word is absent.
        """
        left = 0
        right = self.vocab_size - 1

        while left <= right:
            mid = left + (right - left) // 2
            cmp = self._compare(word, self.vocab[mid])

            if cmp == 0:
                return mid
            elif cmp < 0:
                right = mid - 1
            else:
                left = mid + 1

        return -1

    def encode_stream(self, stream):
        """
        Encodes a word from a stream.

        Args:
            stream (list of str): A list containing the characters of the stream.

        Returns:
            int: The ID of the encoded word, or -1 if nothing matched.
        """
        word = ''
        token_id = -1
        j = 0

        # Greedy longest match: keep extending the candidate and remember
        # the longest prefix found in the vocabulary.
        for i in range(min(MAX_WORD_LEN, len(stream))):
            word += stream[i]
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                j = i + 1

        # Consume only the characters covered by the match.
        del stream[:j]

        return token_id

    def encode_file(self, fd):
        """
        Encodes a word from a file descriptor.

        Args:
            fd (file object): The binary file to encode from.

        Returns:
            int: The ID of the encoded word, or -1 if nothing matched.
        """
        word = ''
        token_id = -1
        j = 0
        bytes_read = 0

        # Greedy longest match, reading one byte at a time. Track bytes
        # separately from characters: decoding ignores invalid bytes, so
        # len(word) can undercount what was actually read.
        for _ in range(MAX_WORD_LEN):
            c = fd.read(1)
            if not c:
                break
            bytes_read += 1
            word += c.decode('utf-8', errors='ignore')
            tmp = self.encode_word(word)
            if tmp != -1:
                token_id = tmp
                j = bytes_read

        # Rewind over the bytes read past the end of the match. Seeking
        # back a fixed MAX_WORD_LEN - j would over-seek when EOF cut the
        # read short.
        to_seek = bytes_read - j
        if to_seek > 0:
            fd.seek(-to_seek, os.SEEK_CUR)

        return token_id

    def decode(self, token_id):
        """Decodes an ID back into its corresponding word."""
        if 0 <= token_id < self.vocab_size:
            return self.vocab[token_id]
        return None

    def decode_file(self, fd):
        """
        Decodes an ID read from a file descriptor back into its corresponding word.

        Args:
            fd (file object): The binary file to decode from.

        Returns:
            str: The decoded word, or None if the ID is out of range.
        """
        data = fd.read(4)
        if len(data) < 4:
            ERROR("read EOF from file\n")

        # IDs are stored as native 32-bit signed integers.
        token_id = struct.unpack('i', data)[0]
        return self.decode(token_id)

    def save_vocab(self, fname):
        """Saves the vocabulary to a text file, one word per line."""
        try:
            with open(fname, 'w', encoding='utf-8') as f:
                max_len = 0
                for i in range(self.vocab_size):
                    word = self.vocab[i]
                    f.write(word + '\n')
                    if len(word) > max_len:
                        max_len = len(word)
            INFO("wrote %d tokens to file \"%s\"\nMax token length was %d",
                 self.vocab_size, fname, max_len)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_vocab(self, fname):
        """Loads the vocabulary from a text file, expecting one word per line."""
        try:
            with open(fname, 'r', encoding='utf-8') as f:
                for line in f:
                    word = line.strip()
                    if word:
                        self.add_word(word)
        except IOError as e:
            ERROR("failed to open \"%s\": %s\n", fname, str(e))

    def save_tokenizer(self, fname):
        """Saves the tokenizer's vocabulary to a binary file."""
        try:
            with open(fname, 'wb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    if i < self.vocab_size:
                        word = self.vocab[i].encode('utf-8')
                        # Truncate and NUL-pad each word to a fixed-width slot.
                        if len(word) >= MAX_WORD_LEN:
                            word = word[:MAX_WORD_LEN - 1]
                        word += b'\0' * (MAX_WORD_LEN - len(word))
                    else:
                        # Unused slots are written as all-zero padding.
                        word = b'\0' * MAX_WORD_LEN
                    f.write(word)
            INFO("wrote %d bytes (%d tokens) to \"%s\"",
                 MAX_VOCAB_SIZE * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_tokenizer(self, fname):
        """Loads the tokenizer's vocabulary from a binary file."""
        try:
            with open(fname, 'rb') as f:
                for _ in range(MAX_VOCAB_SIZE):
                    bytes_word = f.read(MAX_WORD_LEN)
                    if len(bytes_word) < MAX_WORD_LEN:
                        break

                    # Each slot is NUL-padded; keep the bytes before the
                    # first NUL and skip all-zero (unused) slots.
                    word = bytes_word.split(b'\0', 1)[0].decode('utf-8', errors='ignore')
                    if word:
                        self.vocab[self.vocab_size] = word
                        self.vocab_size += 1
            INFO("read %d bytes (%d tokens) from \"%s\"",
                 self.vocab_size * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to read from \"%s\": %s\n", fname, str(e))

    @staticmethod
    def _compare(a, b):
        """Compares two strings like strcmp in C, returning -1, 0, or 1."""
        return (a > b) - (a < b)
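

# A minimal usage sketch (illustrative; not part of the original module).
# It uses an arbitrary example file name, "tok.bin". Note that encode_word
# relies on binary search, so words must be added in sorted order for
# lookups to succeed.
if __name__ == '__main__':
    tok = Tokenizer()
    for w in sorted(['a', 'an', 'and', 'b', 'band']):
        tok.add_word(w)
    tok.save_tokenizer('tok.bin')

    # Round-trip the vocabulary through the binary format.
    tok2 = Tokenizer('tok.bin')

    # Greedy longest match: consumes 'band' in one step, not 'b' + 'and'.
    stream = list('bandand')
    while stream:
        token_id = tok2.encode_stream(stream)
        if token_id == -1:
            break
        INFO("token %d -> %r", token_id, tok2.decode(token_id))

    # encode_file works the same way on a byte stream and rewinds any
    # unconsumed bytes (io.BytesIO stands in for a real file here).
    import io
    buf = io.BytesIO(b'andb')
    INFO("from file: %r", tok2.decode(tok2.encode_file(buf)))
    INFO("next byte: %r", buf.read(1))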