# BPE-Tokenizer / tokenizer.py
import pandas as pd
import re
import encoder_parallel_telugu as encode_parallel
import time
import json
from collections import defaultdict
from tqdm import tqdm
def load_and_encode_tokens():
    """Load the Telugu corpus and byte-encode it in parallel."""
    tokens = encode_parallel.load_telugu_texts()
    start_time = time.time()
    encoded_tokens = encode_parallel.encode_tokens_parallel(tokens, chunk_size=1_000_000, max_workers=10)
    print('encoded_tokens:', encoded_tokens[:100])
    end_time = time.time()
    print(f"Time taken to encode and process tokens in parallel: {end_time - start_time:.4f} seconds")
    print('length of encoded_text:', len(encoded_tokens))
    print('unique characters in decoded_text:', {token.decode('utf-8') for token in set(encoded_tokens)})
    # print('unique characters in encoded_text:', set(encoded_tokens))
    print('unique characters in encoded_text:', len(set(encoded_tokens)))
    return encoded_tokens
def get_stats(ids):
    """Count how often each adjacent pair of ids occurs in the sequence."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
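# Illustrative example (not part of the training run): get_stats counts each
# adjacent pair in a sequence of ids, e.g.
#   get_stats([1, 2, 2, 3, 2, 2]) == {(1, 2): 1, (2, 2): 2, (2, 3): 1, (3, 2): 1}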
def merge(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token id `idx`."""
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids
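# Illustrative example (not part of the training run): merging the pair (2, 2)
# into the new token id 256 collapses every occurrence, left to right, e.g.
#   merge([1, 2, 2, 3, 2, 2], (2, 2), 256) == [1, 256, 3, 256]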
def bpe_process(encoded_tokens, vocab_size=500, encoded_tokens_length=1_000_000):
    """Learn BPE merges by repeatedly replacing the most frequent adjacent pair."""
    num_merges = vocab_size - 256  # our unique tokens are 194, for our sample text
    encoded_tokens = encoded_tokens[:encoded_tokens_length]
    ids = list(encoded_tokens)  # copy so we don't destroy the original list
    merges = {}  # (int, int) -> int
    for i in tqdm(range(num_merges), desc="Merging tokens"):
        stats = get_stats(ids)
        pair = max(stats, key=stats.get)
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx  # each merge maps a pair of tokens to the new token id
    print("tokens length:", len(encoded_tokens))
    print("ids length:", len(ids))
    print("by paired tokens length:", len(set(ids)))
    print(f"compression ratio: {len(encoded_tokens) / len(ids):.2f}X")
    # print(f"token size: {len(set(encoded_tokens))}")
    return merges
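# A minimal sketch (not part of the original script) of how the learned `merges`
# could be applied to a fresh sequence of base ids: replay the merges in the order
# they were learned, exactly as during training. Assumes `new_ids` uses the same
# base-id scheme that `load_and_encode_tokens` produces.
def apply_merges(new_ids, merges):
    ids = list(new_ids)
    for pair, idx in merges.items():  # dicts preserve insertion order, i.e. merge order
        ids = merge(ids, pair, idx)
    return ids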
def build_vocabulary(merges):
    """Build a vocabulary from the Telugu Unicode range plus the learned merges and save it to JSON."""
    telugu_unicode_chars = [chr(i) for i in range(0x0C00, 0x0C7F)]  # Telugu Unicode range
    vocab = {token: idx for token, idx in merges.items()}
    for idx, char in enumerate(ch.encode('utf-8') for ch in telugu_unicode_chars):
        if idx < 256:  # ensure we only add up to 256 base characters
            vocab[char] = idx  # map the UTF-8 encoded character to its index
    vocab[b' '] = 255
    vocab[b'.'] = 254
    with open('merges_vocab.json', 'w') as f:
        json.dump({'merges': {str(k): v for k, v in merges.items()}, 'vocab': {str(k): v for k, v in vocab.items()}}, f)
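# Illustrative shape of merges_vocab.json (keys are str() of the original tuple / bytes
# objects, which is why read_vocab_from_file below rebuilds them with eval); the ids
# shown here are made up:
#   {"merges": {"(256, 257)": 300, ...},
#    "vocab":  {"b'\xe0\xb0\x85'": 5, "(256, 257)": 300, ...}}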
def read_vocab_from_file():
    """Load merges_vocab.json and rebuild the vocab with tuple keys (single entries become 1-tuples)."""
    with open('merges_vocab.json', 'r') as f:
        data = json.load(f)
    distributed_data = defaultdict(list)
    for key, value in data['vocab'].items():
        distributed_data['vocab'].append({key: value})
    formatted_vocab = {}
    for item in distributed_data['vocab']:
        for k, v in item.items():
            if ',' not in k:
                formatted_vocab[(eval(k),)] = v  # keys were serialised with str(), so eval reconstructs them
            else:
                formatted_vocab[eval(k)] = v
    return formatted_vocab
def expand_vocab(inverted_vocab):
    """Recursively expand every token id into the tuple of UTF-8 byte strings it represents."""
    def convert_to_bytes(value):
        if isinstance(value, bytes):
            return value
        elif value in inverted_vocab:
            return process_tuple(inverted_vocab[value])
        else:
            print(f'value not found in inverted_vocab: {value}')
            return None

    def process_tuple(value_tuple):
        converted_values = []
        for v in value_tuple:
            result = convert_to_bytes(v)
            if isinstance(result, tuple):
                converted_values.extend(result)
            else:
                converted_values.append(result)
        return tuple(converted_values)

    decoder_map = {k: process_tuple(v) for k, v in inverted_vocab.items()}
    print("sample decoder map:", {k: decoder_map[k] for k in list(decoder_map)[:5]})
    return decoder_map
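# A minimal sketch (not part of the original script) of how `decoder_map` could be
# used to turn a sequence of token ids back into text: each id expands to a tuple
# of UTF-8 byte strings, which are concatenated and decoded. Assumes every id in
# `ids` is present in `decoder_map` and fully expands to bytes.
def decode_ids(ids, decoder_map):
    parts = []
    for token_id in ids:
        parts.extend(decoder_map[token_id])
    return b''.join(parts).decode('utf-8')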
# Main execution
if __name__ == "__main__":
    # 1. Load and encode tokens
    encoded_tokens = load_and_encode_tokens()
    # 2. Learn BPE merges
    merges = bpe_process(encoded_tokens, vocab_size=5000, encoded_tokens_length=2_000_000)
    # 3. Build vocabulary and save it to disk
    build_vocabulary(merges)
    # 4. Read vocabulary from file
    formatted_vocab = read_vocab_from_file()
    # 5. Invert vocabulary (token id -> key tuple)
    inverted_vocab = {v: k for k, v in formatted_vocab.items()}
    # 6. Expand vocabulary into a decoder map
    decoder_map = expand_vocab(inverted_vocab)
    # 7. Invert back again (byte tuple -> token id)
    re_inverted_vocab = {v: k for k, v in decoder_map.items()}
    # print(re_inverted_vocab)