import os
import struct
import sys

# Constants
MAX_VOCAB_SIZE = 32000
MAX_WORD_LEN = 16


def ERROR(message, *args):
    """Prints an error message to stderr and exits."""
    sys.stderr.write(message % args)
    sys.exit(1)


def INFO(message, *args):
    """Prints an informational message to stdout."""
    print(message % args)


class Tokenizer:
    def __init__(self, fname=None):
        self.vocab_size = 0
        self.vocab = [''] * MAX_VOCAB_SIZE  # Preallocate vocab with empty strings
        if fname:
            self.load_tokenizer(fname)
        INFO("vocabulary size: %d (%d max)", self.vocab_size, MAX_VOCAB_SIZE)
        INFO("max token length: %d", MAX_WORD_LEN)
        # Approximate size of structure: vocab_size * MAX_WORD_LEN + overhead
        structure_size = self.vocab_size * MAX_WORD_LEN
        INFO("size of structure: %d bytes", structure_size)

    def add_word(self, word):
        """Adds a word to the vocabulary.

        Note: encode_word() uses binary search, so words must be added in
        sorted order (or the vocab kept sorted) for lookups to succeed.
        """
        if self.vocab_size >= MAX_VOCAB_SIZE:
            return -1
        # Truncate word if it's longer than MAX_WORD_LEN - 1
        if len(word) >= MAX_WORD_LEN:
            word = word[:MAX_WORD_LEN - 1]
        self.vocab[self.vocab_size] = word
        self.vocab_size += 1
        return self.vocab_size - 1

    def encode_word(self, word):
        """Encodes a word into its corresponding ID using binary search."""
        left = 0
        right = self.vocab_size - 1
        while left <= right:
            mid = left + (right - left) // 2
            cmp = self._compare(word, self.vocab[mid])
            if cmp == 0:
                return mid
            elif cmp < 0:
                right = mid - 1
            else:
                left = mid + 1
        return -1

    def encode_stream(self, stream):
        """
        Encodes a word from a stream.

        Args:
            stream (list of str): A list containing the characters of the stream.

        Returns:
            int: The ID of the encoded word, or -1 if no prefix matched.
        """
        word = ''
        id = -1
        j = 0
        # Greedily extend the candidate word one character at a time and keep
        # the longest prefix that exists in the vocabulary.
        for i in range(min(MAX_WORD_LEN, len(stream))):
            word += stream[i]
            tmp = self.encode_word(word)
            if tmp != -1:
                id = tmp
                j = i + 1
        # Modify the stream in-place to remove the processed characters
        del stream[:j]
        return id

    def encode_file(self, fd):
        """
        Encodes a word from a file descriptor.

        Args:
            fd (file object): The binary file to encode from.

        Returns:
            int: The ID of the encoded word, or -1 if no prefix matched.
        """
        word = ''
        id = -1
        j = 0
        bytes_read = 0
        for _ in range(MAX_WORD_LEN):
            c = fd.read(1)
            if not c:
                break
            bytes_read += 1
            char = c.decode('utf-8', errors='ignore')
            word += char
            tmp = self.encode_word(word)
            if tmp != -1:
                id = tmp
                j = bytes_read
        # Seek back over the bytes that were read but not consumed
        to_seek = bytes_read - j
        if to_seek > 0:
            fd.seek(-to_seek, os.SEEK_CUR)
        return id

    def decode(self, id):
        """Decodes an ID back into its corresponding word."""
        if 0 <= id < self.vocab_size:
            return self.vocab[id]
        return None

    def decode_file(self, fd):
        """
        Decodes an ID read from a file descriptor back into its corresponding word.

        Args:
            fd (file object): The binary file to decode from.

        Returns:
            str: The decoded word.
        """
        data = fd.read(4)  # Read 4 bytes for an integer
        if len(data) < 4:
            ERROR("read EOF from file\n")
        id = struct.unpack('i', data)[0]
        return self.decode(id)

    def save_vocab(self, fname):
        """Saves the vocabulary to a text file, one word per line."""
        try:
            with open(fname, 'w', encoding='utf-8') as f:
                max_len = 0
                for i in range(self.vocab_size):
                    word = self.vocab[i]
                    f.write(word + '\n')
                    if len(word) > max_len:
                        max_len = len(word)
            INFO("wrote %d tokens to file \"%s\"\nMax token length was %d",
                 self.vocab_size, fname, max_len)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_vocab(self, fname):
        """Loads the vocabulary from a text file, expecting one word per line."""
        try:
            with open(fname, 'r', encoding='utf-8') as f:
                for line in f:
                    word = line.strip()
                    if word:
                        self.add_word(word)
        except IOError as e:
            ERROR("failed to open \"%s\": %s\n", fname, str(e))

    def save_tokenizer(self, fname):
        """Saves the vocabulary to a binary file of fixed-size, null-padded records."""
        try:
            with open(fname, 'wb') as f:
                for i in range(MAX_VOCAB_SIZE):
                    if i < self.vocab_size:
                        word = self.vocab[i].encode('utf-8')
                        if len(word) >= MAX_WORD_LEN:
                            word = word[:MAX_WORD_LEN - 1]
                        word += b'\0' * (MAX_WORD_LEN - len(word))
                    else:
                        word = b'\0' * MAX_WORD_LEN
                    f.write(word)
            INFO("wrote %d bytes (%d tokens) to \"%s\"",
                 MAX_VOCAB_SIZE * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to write to \"%s\": %s\n", fname, str(e))

    def load_tokenizer(self, fname):
        """Loads the tokenizer's vocabulary from a binary file."""
        try:
            with open(fname, 'rb') as f:
                for _ in range(MAX_VOCAB_SIZE):
                    bytes_word = f.read(MAX_WORD_LEN)
                    if not bytes_word or len(bytes_word) < MAX_WORD_LEN:
                        break
                    # Decode up to the first null byte
                    word = bytes_word.split(b'\0', 1)[0].decode('utf-8', errors='ignore')
                    if word:
                        self.vocab[self.vocab_size] = word
                        self.vocab_size += 1
            INFO("read %d bytes (%d tokens) from \"%s\"",
                 self.vocab_size * MAX_WORD_LEN, self.vocab_size, fname)
        except IOError as e:
            ERROR("failed to read from \"%s\": %s\n", fname, str(e))

    @staticmethod
    def _compare(a, b):
        """Helper method to compare two strings, similar to strcmp in C."""
        return (a > b) - (a < b)