# mamba/tokenizer.py
from collections import Counter
import struct
import re
import os
import subprocess
class Token:
    """A single vocabulary entry: one byte plus a link to the previous token id."""

    def __init__(self, byte, prev):
        self.byte = byte
        self.prev = prev

    def pack(self):
        if not 0 <= ord(self.byte) <= 255:
            raise ValueError(f"Byte value is out of range, got {self.byte} ({ord(self.byte)})")
        return struct.pack("=B H", ord(self.byte), self.prev)  # 3 bytes: uint8 byte + uint16 prev id

    @classmethod
    def from_binary(cls, data):
        # Inverse of pack(): rebuild a Token from a 3-byte record.
        byte, prev = struct.unpack("=B H", data)
        return cls(chr(byte), prev)

    def __str__(self):
        return f"{self.byte}, {self.prev}"

    def to_binary(self):
        return self.pack()
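
# Each vocabulary entry is a node in a reversed linked list: decoding a token id walks
# the `prev` chain back to a base byte and reverses the collected characters. On disk,
# every entry is a fixed 3-byte record ("=B H": a byte value plus a uint16 prev id),
# presumably the same layout the C tokenizer under c_tokenizer/ consumes.
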
class Tokenizer:
    def __init__(self):
        # Base vocabulary: one token per possible byte value (0-255), each with prev = 0.
        self.vocab = [Token(chr(i), 0) for i in range(256)]

    def find(self, byte, prev):
        # A token whose prefix is `prev` can only sit after index `prev`, so the search
        # starts there. Returns 0 as a "not found" sentinel.
        for i in range(prev, self.vocab_size):
            token = self.vocab[i]
            if token.byte == byte and token.prev == prev:
                return i
        return 0

    def append(self, byte, prev):
        # Return the id of an existing (byte, prev) token, or add a new one.
        token = self.find(byte, prev)
        if token:
            return token
        self.vocab.append(Token(byte, prev))
        return self.vocab_size - 1
    def encode_one(self, text):
        # Greedy longest match: extend the token chain one byte at a time until no
        # continuation exists, then return the matched id and the remaining text.
        prev = 0
        for i in range(len(text)):
            byte = text[i]
            token = self.find(byte, prev)
            if token == 0:
                return prev, text[i:]
            prev = token
        return prev, ''

    def encode(self, text):
        ids = []
        while text:
            token, text = self.encode_one(text)
            ids.append(token)
        return ids
    def decode_one(self, token):
        text = ""
        while token:
            text += self.vocab[token].byte
            token = self.vocab[token].prev
        return text[::-1]

    def decode(self, ids):
        text = ""
        for token in ids:
            text += self.decode_one(token)
        return text
    def add_special(self, text):
        # Register `text` as a token chain; the first character's byte value is its base id.
        token = ord(text[0])
        for byte in text[1:]:
            token = self.append(byte, token)
    @property
    def vocab_size(self):
        return len(self.vocab)

    def __str__(self):
        return '[' + ', '.join(str(token) for token in self.vocab) + ']'
    def to_file(self, file):
        # Overwrite rather than append, so repeated saves do not duplicate the vocabulary.
        with open(file, 'wb') as f:
            for token in self.vocab:
                f.write(token.to_binary())

    def from_file(self, file):
        self.vocab = []
        with open(file, 'rb') as f:
            while True:
                data = f.read(3)  # one 3-byte record per token
                if len(data) < 3:
                    break
                self.vocab.append(Token.from_binary(data))
    def train(self, text, max_length=32000):
        # Build the vocabulary from whole words: keep word characters, prepend a space,
        # then add words in order of decreasing frequency until max_length is reached.
        words = text.split()
        words = [' ' + ''.join(re.findall(r'\w', word)) for word in words]
        words = [word for word in words if len(word) >= 2]
        word_freq = Counter(words)
        sorted_words = sorted(word_freq, key=lambda x: (-word_freq[x], x))
        for word in sorted_words:
            if self.vocab_size > max_length:
                break
            self.add_special(word)
            print(f"adding word: {word} | current vocab size: {self.vocab_size} | max length: {max_length}")
    # Weak part of the project. Maybe implement a handler?
    # check=True at least raises if the external build or run fails.
    def c_compile(self, c_dir):
        subprocess.run(['make'], cwd=c_dir, check=True)

    def c_run(self, c_dir, c_data, c_out):
        subprocess.run(['./a.out', c_data, c_out], cwd=c_dir, check=True)
    def load_binary_file(self, file_path):
        with open(file_path, 'rb') as file:
            data = file.read()
        # Assuming uint16_t is 2 bytes long
        num_values = len(data) // 2
        values = struct.unpack('H' * num_values, data)
        return list(values)
    def c_encode(self, text):
        # Encode via the external C tokenizer: dump the vocab and text to disk,
        # build and run the C program, then read back the uint16 token ids.
        script_dir = os.path.dirname(__file__)
        c_dir = os.path.join(script_dir, 'c_tokenizer/')
        c_vocab = c_dir + 'tokenizer.bin'
        c_data = c_dir + 'dataset.txt'
        c_out = c_dir + 'dataset.bin'
        with open(c_data, 'w') as f:
            f.write(text)
        self.to_file(c_vocab)
        self.c_compile(c_dir)
        self.c_run(c_dir, c_data, c_out)
        ids = self.load_binary_file(c_out)
        return ids
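
# A minimal, illustrative usage sketch (not part of the original module): the training
# text, vocab limit, and input string are made-up values, and only the pure-Python path
# (train / encode / decode) is exercised, not the C pipeline behind c_encode.
if __name__ == "__main__":
    tok = Tokenizer()
    tok.train("the quick brown fox jumps over the lazy dog " * 4, max_length=300)
    ids = tok.encode(" the quick fox")
    print(ids)              # three word-level ids, all above the 256 single-byte base tokens
    print(tok.decode(ids))  # prints " the quick fox"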