"""
Maaza Nano-Orchestrator 9.6M - Custom BPE Tokenizer

Train a tool-focused tokenizer with an 8k vocab.

Key goal: Tool names become single tokens (maaza_extract_json = 1 token, not 5).
"""

import argparse
import json
import re
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple


SPECIAL_TOKENS = [
    "<|pad|>",
    "<|unk|>",
    "<|bos|>",
    "<|eos|>",
    "<|tool_start|>",
    "<|tool_end|>",
    "<|param_start|>",
    "<|param_end|>",
    "<|user|>",
    "<|assistant|>",
    "<|system|>",
]
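
# Illustrative only: one way a training example might be laid out with these
# markers (the actual chat template is defined by the data pipeline, not here):
#   <|user|>what's the weather in tokyo<|assistant|>
#   <|tool_start|>{"tool": "weather_lookup", "params": {"city": "tokyo"}}<|tool_end|>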


TOOL_TOKENS = [
    "maaza_extract_json",
    "mcpbodega_deploy",
    "mcpbodega_list",
    "doom_mcp",
    "bitchat_send",
    "crypto_lookup",
    "scratchpad_mcp",
    "voice_mcp",

    "web_search",
    "web_fetch",
    "puppeteer_navigate",
    "puppeteer_click",
    "puppeteer_screenshot",
    "puppeteer_extract",

    "file_read",
    "file_write",
    "database_query",
    "csv_parse",
    "json_validate",
    "image_caption",

    "code_execute_python",
    "code_execute_js",
    "calculator",
    "regex_match",
    "shell_command",

    "weather_lookup",
    "stock_lookup",
    "news_fetch",
    "email_send",
    "calendar_add",

    "mcpbodega_chat",
    "health_check",
    "slmbench_query",
    "slack_send",
    "github_issue",
    "cyclecore_terminal",
]


JSON_TOKENS = [
    '{"tool"',
    '"params"',
    '"action"',
    '"retry"',
    '"fallback"',
    "true",
    "false",
    "null",
]
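
# These fragments come from the tool-call JSON the model is trained to emit,
# e.g. {"tool": "maaza_extract_json", "params": {"text": "..."}} (see the test
# cases in train_from_dataset below).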


RECOVERY_TOKENS = [
    "retry",
    "fallback",
    "timeout",
    "rate_limit",
    "unavailable",
    "max_retries",
    "backoff",
    "exponential",
    "alternative",
]


class BPETokenizer:
    """Custom BPE tokenizer optimized for tool routing."""

    def __init__(self, vocab_size: int = 8000):
        self.vocab_size = vocab_size
        self.vocab: Dict[str, int] = {}
        self.inverse_vocab: Dict[int, str] = {}
        self.merges: List[Tuple[str, str]] = []

        self._init_special_tokens()

    def _init_special_tokens(self):
        """Initialize the vocabulary with special, tool, JSON, recovery, and byte tokens."""
        idx = 0

        for token in SPECIAL_TOKENS:
            self.vocab[token] = idx
            self.inverse_vocab[idx] = token
            idx += 1

        for token in TOOL_TOKENS:
            self.vocab[token] = idx
            self.inverse_vocab[idx] = token
            idx += 1

        for token in JSON_TOKENS:
            self.vocab[token] = idx
            self.inverse_vocab[idx] = token
            idx += 1

        for token in RECOVERY_TOKENS:
            self.vocab[token] = idx
            self.inverse_vocab[idx] = token
            idx += 1

        # Byte-level fallback: printable ASCII as-is, everything else as <0xNN>.
        for i in range(256):
            char = chr(i) if 32 <= i < 127 else f"<0x{i:02X}>"
            if char not in self.vocab:
                self.vocab[char] = idx
                self.inverse_vocab[idx] = char
                idx += 1

        self.base_vocab_size = idx

    def _get_pairs(self, word: List[str]) -> Counter:
        """Count all adjacent symbol pairs in a word."""
        pairs = Counter()
        for i in range(len(word) - 1):
            pairs[(word[i], word[i + 1])] += 1
        return pairs
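
    # Example (for reference): _get_pairs(["t", "h", "e", "t", "h"]) returns
    # Counter({("t", "h"): 2, ("h", "e"): 1, ("e", "t"): 1}).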

    def _merge_pair(self, pair: Tuple[str, str], word: List[str]) -> List[str]:
        """Merge every occurrence of a specific pair in the word."""
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
                new_word.append(pair[0] + pair[1])
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        return new_word
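
    # Example (for reference): _merge_pair(("t", "h"), ["t", "h", "e", "t", "h"])
    # returns ["th", "e", "th"].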

    def _tokenize_word(self, word: str) -> List[str]:
        """Split a word into characters, keeping known tool names as single symbols."""
        # Whole word is already a reserved vocab entry (tool name, special token, etc.).
        if word in self.vocab:
            return [word]

        # If a tool name is embedded in the word, keep it intact and split the rest.
        for tool in TOOL_TOKENS:
            if tool in word:
                parts = word.split(tool)
                result = []
                for i, part in enumerate(parts):
                    if part:
                        result.extend(list(part))
                    if i < len(parts) - 1:
                        result.append(tool)
                return result

        return list(word)
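
    # Example (for reference): _tokenize_word("run_web_search_now") returns
    # ["r", "u", "n", "_", "web_search", "_", "n", "o", "w"].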

    def train(self, texts: List[str], verbose: bool = True):
        """Train BPE merges on a corpus of texts."""
        if verbose:
            print(f"Training BPE tokenizer (target vocab: {self.vocab_size})")
            print(f"  Base vocab size: {self.base_vocab_size}")

        # Count word frequencies across the corpus.
        word_freqs = Counter()
        for text in texts:
            words = re.findall(r'\w+|[^\w\s]', text.lower())
            word_freqs.update(words)

        # Start from character-level splits (tool names stay whole).
        splits = {}
        for word, freq in word_freqs.items():
            splits[word] = (self._tokenize_word(word), freq)

        num_merges = self.vocab_size - len(self.vocab)
        if verbose:
            print(f"  Performing {num_merges} merges...")

        for merge_idx in range(num_merges):
            # Count pair frequencies, weighted by word frequency.
            pair_freqs = Counter()
            for word, (split, freq) in splits.items():
                pairs = self._get_pairs(split)
                for pair, count in pairs.items():
                    pair_freqs[pair] += count * freq

            if not pair_freqs:
                break

            # Record the most frequent pair and add the merged symbol to the vocab.
            best_pair = pair_freqs.most_common(1)[0][0]
            self.merges.append(best_pair)

            merged = best_pair[0] + best_pair[1]
            if merged not in self.vocab:
                idx = len(self.vocab)
                self.vocab[merged] = idx
                self.inverse_vocab[idx] = merged

            # Apply the merge to every word split.
            for word in splits:
                split, freq = splits[word]
                splits[word] = (self._merge_pair(best_pair, split), freq)

            if verbose and (merge_idx + 1) % 500 == 0:
                print(f"  Merge {merge_idx + 1}: '{best_pair[0]}' + '{best_pair[1]}' -> '{merged}'")

        if verbose:
            print(f"  Final vocab size: {len(self.vocab)}")

    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs."""
        tokens = []

        # Split out special and tool tokens first so they stay atomic.
        special_pattern = '|'.join(re.escape(t) for t in SPECIAL_TOKENS)
        tool_pattern = '|'.join(re.escape(t) for t in TOOL_TOKENS)
        combined_pattern = f'({special_pattern}|{tool_pattern})'

        parts = re.split(combined_pattern, text)

        for part in parts:
            if not part:
                continue

            # Exact match for a reserved token (special, tool, etc.).
            if part in self.vocab:
                tokens.append(self.vocab[part])
                continue

            # Split the remaining text into words, punctuation, and whitespace.
            words = re.findall(r'\w+|[^\w\s]|\s+', part)

            for word in words:
                if word in self.vocab:
                    tokens.append(self.vocab[word])
                    continue

                # Case-insensitive lookup (training lowercases everything).
                word_lower = word.lower()
                if word_lower in self.vocab:
                    tokens.append(self.vocab[word_lower])
                    continue

                # A tool name embedded inside a larger word.
                found_tool = False
                for tool in TOOL_TOKENS:
                    if tool in word_lower:
                        parts_inner = word_lower.split(tool)
                        for i, p in enumerate(parts_inner):
                            if p:
                                tokens.extend(self._encode_subword(p))
                            if i < len(parts_inner) - 1:
                                tokens.append(self.vocab[tool])
                        found_tool = True
                        break

                if found_tool:
                    continue

                # Fall back to BPE subword encoding.
                tokens.extend(self._encode_subword(word_lower))

        return tokens
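
    # For reference: tool names are reserved vocab entries and are split out by
    # combined_pattern above, so encode("crypto_lookup") returns a single token
    # ID (this is exercised by the verification loop in train_from_dataset).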

    def _encode_subword(self, word: str) -> List[int]:
        """Apply BPE merges to encode a subword."""
        if not word:
            return []

        if word in self.vocab:
            return [self.vocab[word]]

        # Start from characters and replay the learned merges in order.
        word_tokens = list(word)

        for pair in self.merges:
            i = 0
            while i < len(word_tokens) - 1:
                if word_tokens[i] == pair[0] and word_tokens[i + 1] == pair[1]:
                    word_tokens = word_tokens[:i] + [pair[0] + pair[1]] + word_tokens[i + 2:]
                else:
                    i += 1

        # Map the resulting symbols to IDs, falling back to <|unk|>.
        ids = []
        for token in word_tokens:
            if token in self.vocab:
                ids.append(self.vocab[token])
            else:
                ids.append(self.vocab["<|unk|>"])

        return ids
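
    # For reference: an unseen word such as "navigate" is encoded as whatever
    # subword pieces the learned merges produce; the exact split depends on the
    # training corpus.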

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to text."""
        tokens = [self.inverse_vocab.get(i, "<|unk|>") for i in ids]
        text = "".join(tokens)

        # Strip special tokens from the output.
        for special in SPECIAL_TOKENS:
            text = text.replace(special, "")

        return text
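
    # Note: encode() lowercases words, so decode(encode(text)) reproduces text
    # exactly only for lowercase ASCII input, and special tokens are dropped.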

    def save(self, path: str):
        """Save the tokenizer to a JSON file."""
        data = {
            "vocab_size": self.vocab_size,
            "vocab": self.vocab,
            "merges": self.merges,
            "special_tokens": SPECIAL_TOKENS,
            "tool_tokens": TOOL_TOKENS,
        }
        with open(path, "w") as f:
            json.dump(data, f, indent=2)
        print(f"Tokenizer saved to {path}")

    @classmethod
    def load(cls, path: str) -> "BPETokenizer":
        """Load a tokenizer from a JSON file."""
        with open(path) as f:
            data = json.load(f)

        tokenizer = cls(vocab_size=data["vocab_size"])
        tokenizer.vocab = data["vocab"]
        tokenizer.inverse_vocab = {int(v): k for k, v in data["vocab"].items()}
        tokenizer.merges = [tuple(m) for m in data["merges"]]

        return tokenizer

    def __len__(self):
        return len(self.vocab)
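

# Minimal usage sketch (illustrative; the texts and paths below are made up):
#
#   tokenizer = BPETokenizer(vocab_size=8000)
#   tokenizer.train(["use web_search to find the docs", "retry with backoff"])
#   ids = tokenizer.encode("call crypto_lookup with retry")
#   print(tokenizer.decode(ids))
#   tokenizer.save("tokenizer.json")
#   tokenizer = BPETokenizer.load("tokenizer.json")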


def train_from_dataset(dataset_path: str, output_path: str = "tokenizer.json", vocab_size: int = 8000):
    """Train a tokenizer from a JSONL dataset file."""
    print(f"Loading dataset from {dataset_path}")

    texts = []
    with open(dataset_path) as f:
        for line in f:
            data = json.loads(line)
            texts.append(data["prompt"])
            texts.append(json.dumps(data["tool_calls"]))

    print(f"Loaded {len(texts)} text samples")

    tokenizer = BPETokenizer(vocab_size=vocab_size)
    tokenizer.train(texts, verbose=True)
    tokenizer.save(output_path)

    # Quick sanity checks on representative inputs.
    print("\n=== Tokenization Tests ===")
    test_cases = [
        "extract the invoice details",
        '{"tool": "maaza_extract_json", "params": {"text": "test"}}',
        "puppeteer_navigate to google.com",
        "The crypto_lookup tool failed with timeout",
        "retry with exponential backoff",
    ]

    for text in test_cases:
        ids = tokenizer.encode(text)
        decoded = tokenizer.decode(ids)
        print(f"\nInput: '{text}'")
        print(f"Tokens: {ids}")
        print(f"Decoded: '{decoded}'")
        print(f"Length: {len(ids)} tokens")

    # Verify that tool names encode to single tokens.
    print("\n=== Tool Token Verification ===")
    for tool in TOOL_TOKENS[:5]:
        ids = tokenizer.encode(tool)
        if len(ids) == 1:
            print(f"✓ {tool} = single token (ID: {ids[0]})")
        else:
            print(f"✗ {tool} = {len(ids)} tokens: {ids}")

    return tokenizer
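

# Typical CLI invocation (script and dataset names are illustrative):
#   python tokenizer.py --input tool_dataset.jsonl --vocab-size 8000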


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train custom BPE tokenizer")
    parser.add_argument("--input", required=True, help="Input dataset (JSONL)")
    parser.add_argument("--output", default="tokenizer.json", help="Output path")
    parser.add_argument("--vocab-size", type=int, default=8000, help="Vocabulary size")

    args = parser.parse_args()

    train_from_dataset(
        dataset_path=args.input,
        output_path=args.output,
        vocab_size=args.vocab_size,
    )

    print(f"\n✓ Tokenizer trained and saved to {args.output}")
    print("Next step: python model.py")