# MixtureofRecursionwithRouter / custom_tokenizer.py
import os
import json
import pickle
import argparse
from collections import Counter
from typing import List, Dict, Optional, Tuple
import re
import unicodedata
class TechnicalTokenizer:
"""
Custom tokenizer optimized for technical content and conversations
"""
def __init__(self, vocab_size: int = 32000, min_freq: int = 2):
self.vocab_size = vocab_size
self.min_freq = min_freq
self.special_tokens = {
'<pad>': 0,
'<unk>': 1,
'<bos>': 2,
'<eos>': 3,
'<system>': 4,
'<user>': 5,
'<assistant>': 6,
'<|endoftext|>': 7,
'<|newline|>': 8,
'<|tab|>': 9,
'<|code|>': 10,
'<|/code|>': 11,
'<|math|>': 12,
'<|/math|>': 13
}
self.vocab = {}
self.id_to_token = {}
self.token_frequencies = Counter()
self.bpe_merges = []
self.bpe_cache = {}
self.code_pattern = re.compile(r'```[\s\S]*?```|`[^`]+`')
        self.url_pattern = re.compile(r'https?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*(),]|%[0-9a-fA-F]{2})+')
self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b')
self.number_pattern = re.compile(r'\b\d+\.?\d*\b')
self.technical_terms = {
'function', 'variable', 'array', 'object', 'class', 'method', 'parameter',
'return', 'import', 'export', 'async', 'await', 'promise', 'callback',
'algorithm', 'datatype', 'boolean', 'integer', 'string', 'float',
'javascript', 'python', 'java', 'cpp', 'html', 'css', 'sql',
'api', 'json', 'xml', 'http', 'https', 'rest', 'graphql',
'equation', 'formula', 'theorem', 'proof', 'hypothesis',
'derivative', 'integral', 'matrix', 'vector', 'polynomial',
'probability', 'statistics', 'correlation', 'regression',
'neural', 'network', 'model', 'training', 'validation', 'test',
'accuracy', 'precision', 'recall', 'f1score', 'loss', 'gradient',
'backpropagation', 'forward', 'layer', 'neuron', 'weight', 'bias',
'transformer', 'attention', 'embedding', 'tokenization',
'database', 'server', 'client', 'protocol', 'encryption', 'security',
'authentication', 'authorization', 'deployment', 'docker', 'kubernetes',
'microservice', 'architecture', 'scalability', 'performance'
}
self._init_vocab()
def _init_vocab(self):
self.vocab = self.special_tokens.copy()
self.id_to_token = {v: k for k, v in self.special_tokens.items()}
    def normalize_text(self, text: str) -> str:
        # Normalize line endings, encode tabs as an explicit token, and apply NFKC.
        text = re.sub(r'\r\n|\r', '\n', text)
        text = re.sub(r'\t', '<|tab|>', text)
        text = unicodedata.normalize('NFKC', text)
        # Temporarily swap code spans for placeholders so the URL/email
        # replacements below cannot rewrite anything inside code.
        code_blocks = []
        def replace_code(match):
            code_blocks.append(match.group())
            return f'<|code|>CODE_BLOCK_{len(code_blocks)-1}<|/code|>'
        text = self.code_pattern.sub(replace_code, text)
        text = self.url_pattern.sub('<URL>', text)
        text = self.email_pattern.sub('<EMAIL>', text)
        # Restore the original code spans in place of their placeholders.
        for i, code_block in enumerate(code_blocks):
            text = text.replace(f'<|code|>CODE_BLOCK_{i}<|/code|>', code_block)
        return text
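    # Illustrative behaviour of normalize_text (a sketch; exact output depends on
    # the patterns configured in __init__):
    #   normalize_text("step 1:\tvisit http://example.com")
    #   -> "step 1:<|tab|>visit <URL>"
    # whereas a URL inside a ```...``` block is left untouched.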
def pre_tokenize(self, text: str) -> List[str]:
text = self.normalize_text(text)
text = re.sub(r'<\|system\|>', ' <system> ', text)
text = re.sub(r'<\|user\|>', ' <user> ', text)
text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)
tokens = re.findall(r'''
<[^>]+>| # Special tokens
\b\w+@\w+\.\w+\b| # Email-like patterns
https?://\S+| # URLs
```[\s\S]*?```| # Code blocks
`[^`]+`| # Inline code
\b\d+\.?\d*\b| # Numbers
\b[a-zA-Z]+(?:'[a-z]*)?| # Words with optional apostrophes
[^\w\s] # Punctuation
''', text, re.VERBOSE)
return [token.strip() for token in tokens if token.strip()]
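    # Illustrative pre_tokenize output (a sketch):
    #   pre_tokenize("The model has 3 layers.")
    #   -> ['The', 'model', 'has', '3', 'layers', '.']
    # Whitespace is dropped at this stage, which is why decode_ids cannot restore it.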
def get_pairs(self, word_freqs: Dict[Tuple[str, ...], int]) -> Counter:
pairs = Counter()
for word, freq in word_freqs.items():
if len(word) < 2:
continue
for i in range(len(word) - 1):
pair = (word[i], word[i + 1])
pairs[pair] += freq
return pairs
def merge_symbols(self, pair: Tuple[str, str], word_freqs: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
new_word_freqs = {}
bigram = pair
for word, freq in word_freqs.items():
new_word = []
i = 0
while i < len(word):
if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
new_word.append(word[i] + word[i + 1])
i += 2
else:
new_word.append(word[i])
i += 1
new_word_freqs[tuple(new_word)] = freq
return new_word_freqs
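    # Worked example of one BPE step (illustrative counts):
    #   word_freqs = {('h', 'i'): 5, ('h', 'i', 'g', 'h'): 2}
    #   get_pairs(word_freqs).most_common(1)   -> [(('h', 'i'), 7)]
    #   merge_symbols(('h', 'i'), word_freqs)  -> {('hi',): 5, ('hi', 'g', 'h'): 2}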
def train_bpe(self, texts: List[str]) -> None:
print("Training BPE tokenizer...")
word_freqs = Counter()
for i, text in enumerate(texts):
if i % 10000 == 0:
print(f"Processing text {i}/{len(texts)}")
tokens = self.pre_tokenize(text)
for token in tokens:
char_seq = tuple(token)
if len(char_seq) > 0:
word_freqs[char_seq] += 1
print(f"Found {len(word_freqs)} unique word patterns")
word_freqs = {word: freq for word, freq in word_freqs.items() if freq >= self.min_freq}
        # Boost whole technical terms so their character sequences survive
        # min_freq filtering and get merged early. Keys in word_freqs are
        # character tuples, so the term must be converted with tuple(term).
        for term in self.technical_terms:
            term_seq = tuple(term)
            if term_seq in word_freqs:
                word_freqs[term_seq] *= 10
all_chars = set()
for word in word_freqs:
all_chars.update(word)
for char in sorted(all_chars):
if char not in self.vocab:
self.vocab[char] = len(self.vocab)
self.id_to_token[len(self.id_to_token)] = char
        # self.vocab already holds the special tokens plus the base characters,
        # so the remaining merge budget is measured against vocab_size directly.
        num_merges = max(0, self.vocab_size - len(self.vocab))
for i in range(num_merges):
if i % 1000 == 0:
print(f"BPE merge {i}/{num_merges}")
pairs = self.get_pairs(word_freqs)
if not pairs:
break
best_pair = pairs.most_common(1)[0][0]
word_freqs = self.merge_symbols(best_pair, word_freqs)
merged_token = best_pair[0] + best_pair[1]
if merged_token not in self.vocab:
self.vocab[merged_token] = len(self.vocab)
self.id_to_token[len(self.id_to_token)] = merged_token
self.bpe_merges.append(best_pair)
print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")
for word, freq in word_freqs.items():
for token in word:
self.token_frequencies[token] += freq
def apply_bpe(self, word: str) -> List[str]:
if word in self.bpe_cache:
return self.bpe_cache[word]
tokens = list(word)
for merge in self.bpe_merges:
i = 0
while i < len(tokens) - 1:
if tokens[i] == merge[0] and tokens[i + 1] == merge[1]:
tokens = tokens[:i] + [merge[0] + merge[1]] + tokens[i + 2:]
else:
i += 1
self.bpe_cache[word] = tokens
return tokens
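    # Illustrative apply_bpe call (the merge list is hypothetical):
    #   with self.bpe_merges == [('h', 'i'), ('hi', 'g')]
    #   apply_bpe("high") -> ['hig', 'h']
    # Results are cached per word, so repeated words are segmented only once.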
def tokenize(self, text: str) -> List[str]:
pre_tokens = self.pre_tokenize(text)
final_tokens = []
for token in pre_tokens:
if token in self.special_tokens or token in self.vocab:
final_tokens.append(token)
else:
bpe_tokens = self.apply_bpe(token)
final_tokens.extend(bpe_tokens)
return final_tokens
def encode_ids(self, text: str, add_special_tokens: bool = True) -> List[int]:
tokens = self.tokenize(text)
if add_special_tokens:
tokens = ['<bos>'] + tokens + ['<eos>']
ids = []
for token in tokens:
ids.append(self.vocab.get(token, self.vocab['<unk>']))
return ids
    def decode_ids(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        tokens = []
        for token_id in ids:
            token = self.id_to_token.get(token_id, '<unk>')
            if skip_special_tokens and token in self.special_tokens:
                continue
            tokens.append(token)
        # Tokens are concatenated without a word-boundary marker, so decoding does
        # not reinsert the spaces that pre_tokenize discarded.
        text = ''.join(tokens)
        text = text.replace('<|tab|>', '\t')
        text = text.replace('<|newline|>', '\n')
        return text
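    # Illustrative round trip (token ids depend on the trained vocabulary):
    #   ids = tokenizer.encode_ids("use the api")   # -> [2, ..., 3] with <bos>/<eos>
    #   tokenizer.decode_ids(ids)                   # -> "usetheapi"
    # Spaces are not recoverable because pre_tokenize discards them and
    # decode_ids joins tokens with an empty string.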
def save(self, save_dir: str):
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
json.dump(self.vocab, f, indent=2, ensure_ascii=False)
with open(os.path.join(save_dir, 'merges.txt'), 'w', encoding='utf-8') as f:
for merge in self.bpe_merges:
f.write(f"{merge[0]} {merge[1]}\n")
config = {
'vocab_size': self.vocab_size,
'min_freq': self.min_freq,
'special_tokens': self.special_tokens,
'technical_terms': list(self.technical_terms)
}
with open(os.path.join(save_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2, ensure_ascii=False)
with open(os.path.join(save_dir, 'token_frequencies.pkl'), 'wb') as f:
pickle.dump(dict(self.token_frequencies), f)
print(f"Tokenizer saved to {save_dir}")
def load(self, save_dir: str):
with open(os.path.join(save_dir, 'vocab.json'), 'r', encoding='utf-8') as f:
self.vocab = json.load(f)
self.id_to_token = {v: k for k, v in self.vocab.items()}
with open(os.path.join(save_dir, 'merges.txt'), 'r', encoding='utf-8') as f:
self.bpe_merges = [tuple(line.strip().split()) for line in f if line.strip()]
config_file = os.path.join(save_dir, 'tokenizer_config.json')
if os.path.exists(config_file):
with open(config_file, 'r', encoding='utf-8') as f:
config = json.load(f)
self.vocab_size = config.get('vocab_size', self.vocab_size)
self.min_freq = config.get('min_freq', self.min_freq)
if 'technical_terms' in config:
self.technical_terms = set(config['technical_terms'])
freq_file = os.path.join(save_dir, 'token_frequencies.pkl')
if os.path.exists(freq_file):
with open(freq_file, 'rb') as f:
self.token_frequencies = Counter(pickle.load(f))
self.bpe_cache = {}
print(f"Tokenizer loaded from {save_dir}")
print(f"Vocabulary size: {len(self.vocab)}")
print(f"Number of BPE merges: {len(self.bpe_merges)}")
def get_vocab_size(self) -> int:
return len(self.vocab)
def get_token_frequency(self, token: str) -> int:
return self.token_frequencies.get(token, 0)
def analyze_tokenization(self, text: str):
tokens = self.tokenize(text)
ids = self.encode_ids(text, add_special_tokens=False)
print(f"Original text: {text}")
print(f"Tokens: {tokens}")
print(f"Token IDs: {ids}")
print(f"Number of tokens: {len(tokens)}")
print(f"Compression ratio: {len(text.split())/len(tokens):.2f}")
return tokens, ids
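# Minimal usage sketch for TechnicalTokenizer (corpus and paths are illustrative):
#
#   tokenizer = TechnicalTokenizer(vocab_size=8000, min_freq=2)
#   tokenizer.train_bpe(["How do I write a Python function?",
#                        "The gradient update uses backpropagation."])
#   ids = tokenizer.encode_ids("Explain the attention layer.")
#   text = tokenizer.decode_ids(ids)
#   tokenizer.save("tokenizer")        # writes vocab.json, merges.txt, config, frequencies
#   restored = TechnicalTokenizer()
#   restored.load("tokenizer")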
class ConversationDataset:
"""Dataset class for handling conversation data with the custom tokenizer"""
def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 512):
self.data_file = data_file
self.tokenizer = tokenizer
self.max_length = max_length
self.conversations = []
self.load_conversations()
def load_conversations(self):
print(f"Loading conversations from {self.data_file}")
if self.data_file.endswith('.jsonl'):
self.load_jsonl()
else:
self.load_text()
print(f"Loaded {len(self.conversations)} conversations")
def load_jsonl(self):
with open(self.data_file, 'r', encoding='utf-8') as f:
for line in f:
try:
conv = json.loads(line.strip())
messages = conv.get("messages", [])
if not messages:
continue
text_parts = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "").strip()
if not content:
continue
if role == "system":
continue
elif role == "user":
text_parts.append(f"<user> {content}")
elif role == "assistant":
text_parts.append(f"<assistant> {content}")
if len(text_parts) >= 2:
conversation_text = " ".join(text_parts) + " <|endoftext|>"
self.conversations.append(conversation_text)
except json.JSONDecodeError:
continue
    def load_text(self):
        with open(self.data_file, 'r', encoding='utf-8') as f:
            content = f.read()
        # Split on the end-of-text marker itself so a missing trailing newline
        # does not leave the marker embedded in the final conversation.
        conversations = content.split('<|endoftext|>')
        for conv in conversations:
            conv = conv.strip()
            if conv:
                self.conversations.append(conv + " <|endoftext|>")
def get_tokenized_conversations(self, include_stats=False):
tokenized = []
stats = {'total_tokens': 0, 'truncated': 0, 'avg_length': 0}
for conv in self.conversations:
tokens = self.tokenizer.encode_ids(conv)
if len(tokens) > self.max_length:
tokens = tokens[:self.max_length]
stats['truncated'] += 1
tokenized.append(tokens)
stats['total_tokens'] += len(tokens)
if tokenized:
stats['avg_length'] = stats['total_tokens'] / len(tokenized)
if include_stats:
return tokenized, stats
return tokenized
    def create_training_examples(self, stride: Optional[int] = None):
if stride is None:
stride = self.max_length // 2
examples = []
for conv in self.conversations:
tokens = self.tokenizer.encode_ids(conv)
if len(tokens) <= self.max_length:
examples.append(tokens)
else:
for i in range(0, len(tokens), stride):
window = tokens[i:i + self.max_length]
if len(window) >= 32:
examples.append(window)
return examples
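# Minimal usage sketch for ConversationDataset (the file path is illustrative):
#
#   dataset = ConversationDataset("conversations.jsonl", tokenizer, max_length=512)
#   examples = dataset.create_training_examples()               # list of id windows
#   tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
#   print(stats['avg_length'], stats['truncated'])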
def train_tokenizer_from_files(file_paths: List[str],
                               vocab_size: int = 32000,
                               min_freq: int = 2,
                               output_dir: str = "tokenizer",
                               max_texts: Optional[int] = None):
print(f"Training tokenizer with vocab_size={vocab_size}")
print(f"Input files: {file_paths}")
all_texts = []
for file_path in file_paths:
print(f"Loading {file_path}...")
if file_path.endswith('.jsonl'):
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
try:
conv = json.loads(line.strip())
messages = conv.get("messages", [])
text_parts = []
for msg in messages:
content = msg.get("content", "").strip()
if content:
text_parts.append(content)
if text_parts:
all_texts.append(" ".join(text_parts))
except json.JSONDecodeError:
continue
else:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
chunks = content.split('\n\n')
for chunk in chunks:
if chunk.strip():
all_texts.append(chunk.strip())
print(f"Loaded {len(all_texts)} texts")
if max_texts and len(all_texts) > max_texts:
import random
random.shuffle(all_texts)
all_texts = all_texts[:max_texts]
print(f"Limited to {len(all_texts)} texts")
tokenizer = TechnicalTokenizer(vocab_size=vocab_size, min_freq=min_freq)
tokenizer.train_bpe(all_texts)
tokenizer.save(output_dir)
print("\nTesting tokenization on sample texts:")
test_texts = [
"Hello, how can I help you with your Python programming question?",
"The neural network has 3 hidden layers with ReLU activation functions.",
"```python\ndef fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n```",
"The derivative of x^2 is 2x, and the integral is (x^3)/3 + C."
]
for text in test_texts:
tokenizer.analyze_tokenization(text)
print()
return tokenizer
def main():
parser = argparse.ArgumentParser(description="Train custom tokenizer for technical content")
parser.add_argument("--input_files", nargs='+', help="Input text/jsonl files")
parser.add_argument("--output_dir", default="tokenizer", help="Output directory for tokenizer")
parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size")
parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency")
parser.add_argument("--max_texts", type=int, help="Maximum number of texts to use for training")
parser.add_argument("--test_file", help="Test file for analyzing tokenization")
parser.add_argument("--load_tokenizer", help="Load existing tokenizer from directory")
args = parser.parse_args()
default_input_file = "/kaggle/input/gpt-based-slm-dataset/slm_training_complete.jsonl"
default_text_file = "/kaggle/working/text_data/training_data_chat.txt"
if not args.input_files and not args.load_tokenizer:
if os.path.exists(default_input_file):
args.input_files = [default_input_file]
print(f"No arguments provided, using default input file: {default_input_file}")
elif os.path.exists(default_text_file):
args.input_files = [default_text_file]
print(f"No arguments provided, using default text file: {default_text_file}")
else:
parser.error("No input files or tokenizer directory provided, and default files not found. "
"Please specify --input_files or --load_tokenizer.")
    if args.load_tokenizer:
        tokenizer = TechnicalTokenizer()
        tokenizer.load(args.load_tokenizer)
    else:
        tokenizer = train_tokenizer_from_files(
            file_paths=args.input_files,
            vocab_size=args.vocab_size,
            min_freq=args.min_freq,
            output_dir=args.output_dir,
            max_texts=args.max_texts
        )
    if args.test_file:
        print(f"\nTesting on {args.test_file}")
        dataset = ConversationDataset(args.test_file, tokenizer)
        tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
        print("Dataset statistics:")
        print(f"  Total conversations: {len(tokenized)}")
        print(f"  Total tokens: {stats['total_tokens']:,}")
        print(f"  Average tokens per conversation: {stats['avg_length']:.1f}")
        print(f"  Conversations truncated: {stats['truncated']}")
if __name__ == "__main__":
main()
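# Example invocations (paths are illustrative):
#
#   python custom_tokenizer.py --input_files data/train.jsonl --output_dir tokenizer \
#       --vocab_size 32000 --min_freq 2
#   python custom_tokenizer.py --load_tokenizer tokenizer --test_file data/val.jsonl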