Girinath11 committed
Commit 02aea9b (verified)
Parent(s): 5a4d89b

Update custom_tokenizer.py

Files changed (1)
  1. custom_tokenizer.py +1029 -393
custom_tokenizer.py CHANGED
@@ -2,22 +2,240 @@ import os
2
  import json
3
  import pickle
4
  import argparse
5
- from collections import Counter, defaultdict
6
- from typing import List, Dict, Set, Optional, Tuple
7
  import re
8
  import unicodedata
9
  class TechnicalTokenizer:
10
  """
11
- Custom tokenizer optimized for technical content and conversations
12
- """
13
- def __init__(self, vocab_size: int = 32000, min_freq: int = 2):
14
- self.vocab_size = vocab_size
15
- self.min_freq = min_freq
16
  self.special_tokens = {
17
- '<pad>': 0,
18
- '<unk>': 1,
19
- '<bos>': 2,
20
- '<eos>': 3,
21
  '<system>': 4,
22
  '<user>': 5,
23
  '<assistant>': 6,
@@ -27,411 +245,829 @@ class TechnicalTokenizer:
27
  '<|code|>': 10,
28
  '<|/code|>': 11,
29
  '<|math|>': 12,
30
- '<|/math|>': 13
31
- }
32
- self.vocab = {}
33
- self.id_to_token = {}
34
- self.token_frequencies = Counter()
35
- self.bpe_merges = []
36
- self.bpe_cache = {}
37
- self.code_pattern = re.compile(r'```[\s\S]*?```|`[^`]+`')
38
- self.url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
39
- self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b')
40
- self.number_pattern = re.compile(r'\b\d+\.?\d*\b')
41
- self.technical_terms = {
42
- 'function', 'variable', 'array', 'object', 'class', 'method', 'parameter',
43
- 'return', 'import', 'export', 'async', 'await', 'promise', 'callback',
44
- 'algorithm', 'datatype', 'boolean', 'integer', 'string', 'float',
45
- 'javascript', 'python', 'java', 'cpp', 'html', 'css', 'sql',
46
- 'api', 'json', 'xml', 'http', 'https', 'rest', 'graphql',
47
- 'equation', 'formula', 'theorem', 'proof', 'hypothesis',
48
- 'derivative', 'integral', 'matrix', 'vector', 'polynomial',
49
- 'probability', 'statistics', 'correlation', 'regression',
50
- 'neural', 'network', 'model', 'training', 'validation', 'test',
51
- 'accuracy', 'precision', 'recall', 'f1score', 'loss', 'gradient',
52
- 'backpropagation', 'forward', 'layer', 'neuron', 'weight', 'bias',
53
- 'transformer', 'attention', 'embedding', 'tokenization',
54
- 'database', 'server', 'client', 'protocol', 'encryption', 'security',
55
- 'authentication', 'authorization', 'deployment', 'docker', 'kubernetes',
56
- 'microservice', 'architecture', 'scalability', 'performance'
57
- }
58
- self._init_vocab()
59
- def _init_vocab(self):
60
  self.vocab = self.special_tokens.copy()
61
- self.id_to_token = {v: k for k, v in self.special_tokens.items()}
62
  def normalize_text(self, text: str) -> str:
63
- text = re.sub(r'\r\n|\r', '\n', text)
64
- text = re.sub(r'\t', '<|tab|>', text)
65
- text = unicodedata.normalize('NFKC', text)
66
- code_blocks = []
67
- def replace_code(match):
68
- code_blocks.append(match.group())
69
- return f'<|code|>CODE_BLOCK_{len(code_blocks)-1}<|/code|>'
70
- text = self.code_pattern.sub(replace_code, text)
71
- text = self.url_pattern.sub('<URL>', text)
72
- text = self.email_pattern.sub('<EMAIL>', text)
73
- for i, code_block in enumerate(code_blocks):
74
- text = text.replace(f'<|code|>CODE_BLOCK_{i}<|/code|>', code_block)
75
- return text
76
  def pre_tokenize(self, text: str) -> List[str]:
77
- text = self.normalize_text(text)
78
- text = re.sub(r'<\|system\|>', ' <system> ', text)
79
- text = re.sub(r'<\|user\|>', ' <user> ', text)
80
- text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
81
- text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)
82
- tokens = re.findall(r'''
83
- <[^>]+>| # Special tokens
84
- \b\w+@\w+\.\w+\b| # Email-like patterns
85
- https?://\S+| # URLs
86
- ```[\s\S]*?```| # Code blocks
87
- `[^`]+`| # Inline code
88
- \b\d+\.?\d*\b| # Numbers
89
- \b[a-zA-Z]+(?:'[a-z]*)?| # Words with optional apostrophes
90
- [^\w\s] # Punctuation
91
- ''', text, re.VERBOSE)
92
- return [token.strip() for token in tokens if token.strip()]
93
- def get_pairs(self, word_freqs: Dict[Tuple[str, ...], int]) -> Counter:
94
- pairs = Counter()
95
- for word, freq in word_freqs.items():
96
- if len(word) < 2:
97
- continue
98
- for i in range(len(word) - 1):
99
- pair = (word[i], word[i + 1])
100
- pairs[pair] += freq
101
- return pairs
102
- def merge_symbols(self, pair: Tuple[str, str], word_freqs: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
103
- new_word_freqs = {}
104
- bigram = pair
105
- for word, freq in word_freqs.items():
106
- new_word = []
107
- i = 0
108
- while i < len(word):
109
- if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
110
- new_word.append(word[i] + word[i + 1])
111
- i += 2
112
- else:
113
- new_word.append(word[i])
114
- i += 1
115
- new_word_freqs[tuple(new_word)] = freq
116
- return new_word_freqs
117
- def train_bpe(self, texts: List[str]) -> None:
118
- print("Training BPE tokenizer...")
119
- word_freqs = Counter()
120
- for i, text in enumerate(texts):
121
- if i % 10000 == 0:
122
- print(f"Processing text {i}/{len(texts)}")
123
- tokens = self.pre_tokenize(text)
124
- for token in tokens:
125
- char_seq = tuple(token)
126
- if len(char_seq) > 0:
127
- word_freqs[char_seq] += 1
128
- print(f"Found {len(word_freqs)} unique word patterns")
129
- word_freqs = {word: freq for word, freq in word_freqs.items() if freq >= self.min_freq}
130
- for term in self.technical_terms:
131
- if (term,) in word_freqs:
132
- word_freqs[(term,)] *= 10
133
  all_chars = set()
134
- for word in word_freqs:
135
- all_chars.update(word)
136
  for char in sorted(all_chars):
137
  if char not in self.vocab:
138
- self.vocab[char] = len(self.vocab)
139
- self.id_to_token[len(self.id_to_token)] = char
140
- target_vocab_size = self.vocab_size - len(self.special_tokens)
141
- num_merges = target_vocab_size - len(self.vocab)
142
- for i in range(num_merges):
143
- if i % 1000 == 0:
144
- print(f"BPE merge {i}/{num_merges}")
145
- pairs = self.get_pairs(word_freqs)
146
- if not pairs:
147
- break
148
- best_pair = pairs.most_common(1)[0][0]
149
- word_freqs = self.merge_symbols(best_pair, word_freqs)
150
- merged_token = best_pair[0] + best_pair[1]
151
- if merged_token not in self.vocab:
152
- self.vocab[merged_token] = len(self.vocab)
153
- self.id_to_token[len(self.id_to_token)] = merged_token
154
- self.bpe_merges.append(best_pair)
155
- print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")
156
- for word, freq in word_freqs.items():
157
- for token in word:
158
- self.token_frequencies[token] += freq
159
  def apply_bpe(self, word: str) -> List[str]:
160
- if word in self.bpe_cache:
161
- return self.bpe_cache[word]
162
- tokens = list(word)
163
- for merge in self.bpe_merges:
164
- i = 0
165
- while i < len(tokens) - 1:
166
- if tokens[i] == merge[0] and tokens[i + 1] == merge[1]:
167
- tokens = tokens[:i] + [merge[0] + merge[1]] + tokens[i + 2:]
168
- else:
169
- i += 1
170
- self.bpe_cache[word] = tokens
171
- return tokens
172
  def tokenize(self, text: str) -> List[str]:
173
- pre_tokens = self.pre_tokenize(text)
174
- final_tokens = []
175
- for token in pre_tokens:
176
- if token in self.special_tokens or token in self.vocab:
177
- final_tokens.append(token)
178
- else:
179
- bpe_tokens = self.apply_bpe(token)
180
- final_tokens.extend(bpe_tokens)
181
- return final_tokens
182
- def encode_ids(self, text: str, add_special_tokens: bool = True) -> List[int]:
183
  tokens = self.tokenize(text)
184
  if add_special_tokens:
185
- tokens = ['<bos>'] + tokens + ['<eos>']
186
  ids = []
187
  for token in tokens:
188
- ids.append(self.vocab.get(token, self.vocab['<unk>']))
189
- return ids
190
- def decode_ids(self, ids: List[int], skip_special_tokens: bool = True) -> str:
191
  tokens = []
192
- for id in ids:
193
- token = self.id_to_token.get(id, '<unk>')
194
  if skip_special_tokens and token in self.special_tokens:
195
  continue
196
- tokens.append(token)
197
  text = ''.join(tokens)
198
  text = text.replace('<|tab|>', '\t')
199
  text = text.replace('<|newline|>', '\n')
200
- return text
201
- def save(self, save_dir: str):
202
- os.makedirs(save_dir, exist_ok=True)
203
- with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
204
- json.dump(self.vocab, f, indent=2, ensure_ascii=False)
205
- with open(os.path.join(save_dir, 'merges.txt'), 'w', encoding='utf-8') as f:
206
- for merge in self.bpe_merges:
207
- f.write(f"{merge[0]} {merge[1]}\n")
208
- config = {
209
- 'vocab_size': self.vocab_size,
210
- 'min_freq': self.min_freq,
211
- 'special_tokens': self.special_tokens,
212
- 'technical_terms': list(self.technical_terms)
213
- }
214
- with open(os.path.join(save_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
215
- json.dump(config, f, indent=2, ensure_ascii=False)
216
- with open(os.path.join(save_dir, 'token_frequencies.pkl'), 'wb') as f:
217
- pickle.dump(dict(self.token_frequencies), f)
218
- print(f"Tokenizer saved to {save_dir}")
219
- def load(self, save_dir: str):
220
- with open(os.path.join(save_dir, 'vocab.json'), 'r', encoding='utf-8') as f:
221
- self.vocab = json.load(f)
222
- self.id_to_token = {v: k for k, v in self.vocab.items()}
223
- with open(os.path.join(save_dir, 'merges.txt'), 'r', encoding='utf-8') as f:
224
- self.bpe_merges = [tuple(line.strip().split()) for line in f if line.strip()]
225
- config_file = os.path.join(save_dir, 'tokenizer_config.json')
226
- if os.path.exists(config_file):
227
- with open(config_file, 'r', encoding='utf-8') as f:
228
- config = json.load(f)
229
- self.vocab_size = config.get('vocab_size', self.vocab_size)
230
- self.min_freq = config.get('min_freq', self.min_freq)
231
- if 'technical_terms' in config:
232
- self.technical_terms = set(config['technical_terms'])
233
- freq_file = os.path.join(save_dir, 'token_frequencies.pkl')
234
- if os.path.exists(freq_file):
235
- with open(freq_file, 'rb') as f:
236
- self.token_frequencies = Counter(pickle.load(f))
237
- self.bpe_cache = {}
238
- print(f"Tokenizer loaded from {save_dir}")
239
- print(f"Vocabulary size: {len(self.vocab)}")
240
- print(f"Number of BPE merges: {len(self.bpe_merges)}")
241
  def get_vocab_size(self) -> int:
242
- return len(self.vocab)
243
- def get_token_frequency(self, token: str) -> int:
244
- return self.token_frequencies.get(token, 0)
245
- def analyze_tokenization(self, text: str):
246
- tokens = self.tokenize(text)
247
- ids = self.encode_ids(text, add_special_tokens=False)
248
- print(f"Original text: {text}")
249
- print(f"Tokens: {tokens}")
250
- print(f"Token IDs: {ids}")
251
- print(f"Number of tokens: {len(tokens)}")
252
- print(f"Compression ratio: {len(text.split())/len(tokens):.2f}")
253
- return tokens, ids
254
- class ConversationDataset:
255
- """Dataset class for handling conversation data with the custom tokenizer"""
256
- def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 512):
257
- self.data_file = data_file
258
- self.tokenizer = tokenizer
259
- self.max_length = max_length
260
- self.conversations = []
261
- self.load_conversations()
262
- def load_conversations(self):
263
- print(f"Loading conversations from {self.data_file}")
264
- if self.data_file.endswith('.jsonl'):
265
- self.load_jsonl()
266
- else:
267
- self.load_text()
268
- print(f"Loaded {len(self.conversations)} conversations")
269
- def load_jsonl(self):
270
- with open(self.data_file, 'r', encoding='utf-8') as f:
271
- for line in f:
272
- try:
273
- conv = json.loads(line.strip())
274
- messages = conv.get("messages", [])
275
- if not messages:
276
- continue
277
- text_parts = []
278
- for msg in messages:
279
- role = msg.get("role", "")
280
- content = msg.get("content", "").strip()
281
- if not content:
282
- continue
283
- if role == "system":
284
  continue
285
- elif role == "user":
286
- text_parts.append(f"<user> {content}")
287
- elif role == "assistant":
288
- text_parts.append(f"<assistant> {content}")
289
-
290
- if len(text_parts) >= 2:
291
- conversation_text = " ".join(text_parts) + " <|endoftext|>"
292
- self.conversations.append(conversation_text)
293
- except json.JSONDecodeError:
294
- continue
295
- def load_text(self):
296
- with open(self.data_file, 'r', encoding='utf-8') as f:
297
- content = f.read()
298
- conversations = content.split('<|endoftext|>\n')
299
- for conv in conversations:
300
- conv = conv.strip()
301
- if conv:
302
- self.conversations.append(conv + " <|endoftext|>")
303
- def get_tokenized_conversations(self, include_stats=False):
304
- tokenized = []
305
- stats = {'total_tokens': 0, 'truncated': 0, 'avg_length': 0}
306
- for conv in self.conversations:
307
- tokens = self.tokenizer.encode_ids(conv)
308
- if len(tokens) > self.max_length:
309
- tokens = tokens[:self.max_length]
310
- stats['truncated'] += 1
311
- tokenized.append(tokens)
312
- stats['total_tokens'] += len(tokens)
313
- if tokenized:
314
- stats['avg_length'] = stats['total_tokens'] / len(tokenized)
315
- if include_stats:
316
- return tokenized, stats
317
- return tokenized
318
- def create_training_examples(self, stride: int = None):
319
- if stride is None:
320
- stride = self.max_length // 2
321
- examples = []
322
- for conv in self.conversations:
323
- tokens = self.tokenizer.encode_ids(conv)
324
- if len(tokens) <= self.max_length:
325
- examples.append(tokens)
326
  else:
327
- for i in range(0, len(tokens), stride):
328
- window = tokens[i:i + self.max_length]
329
- if len(window) >= 32:
330
- examples.append(window)
331
- return examples
332
- def train_tokenizer_from_files(file_paths: List[str],
333
- vocab_size: int = 32000,
334
- min_freq: int = 2,
335
- output_dir: str = "tokenizer",
336
- max_texts: int = None):
337
- print(f"Training tokenizer with vocab_size={vocab_size}")
338
- print(f"Input files: {file_paths}")
339
- all_texts = []
340
- for file_path in file_paths:
341
- print(f"Loading {file_path}...")
342
- if file_path.endswith('.jsonl'):
343
- with open(file_path, 'r', encoding='utf-8') as f:
344
- for line in f:
345
- try:
346
- conv = json.loads(line.strip())
347
- messages = conv.get("messages", [])
348
- text_parts = []
349
- for msg in messages:
350
- content = msg.get("content", "").strip()
351
- if content:
352
- text_parts.append(content)
353
- if text_parts:
354
- all_texts.append(" ".join(text_parts))
355
- except json.JSONDecodeError:
356
- continue
357
- else:
358
- with open(file_path, 'r', encoding='utf-8') as f:
359
- content = f.read()
360
- chunks = content.split('\n\n')
361
- for chunk in chunks:
362
- if chunk.strip():
363
- all_texts.append(chunk.strip())
364
- print(f"Loaded {len(all_texts)} texts")
365
- if max_texts and len(all_texts) > max_texts:
366
- import random
367
- random.shuffle(all_texts)
368
- all_texts = all_texts[:max_texts]
369
- print(f"Limited to {len(all_texts)} texts")
370
- tokenizer = TechnicalTokenizer(vocab_size=vocab_size, min_freq=min_freq)
371
- tokenizer.train_bpe(all_texts)
372
- tokenizer.save(output_dir)
373
- print("\nTesting tokenization on sample texts:")
374
- test_texts = [
375
- "Hello, how can I help you with your Python programming question?",
376
- "The neural network has 3 hidden layers with ReLU activation functions.",
377
- "```python\ndef fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n```",
378
- "The derivative of x^2 is 2x, and the integral is (x^3)/3 + C."
379
- ]
380
- for text in test_texts:
381
- tokenizer.analyze_tokenization(text)
382
- print()
383
  return tokenizer
384
  def main():
385
- parser = argparse.ArgumentParser(description="Train custom tokenizer for technical content")
386
- parser.add_argument("--input_files", nargs='+', help="Input text/jsonl files")
387
- parser.add_argument("--output_dir", default="tokenizer", help="Output directory for tokenizer")
388
- parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size")
389
- parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency")
390
- parser.add_argument("--max_texts", type=int, help="Maximum number of texts to use for training")
391
- parser.add_argument("--test_file", help="Test file for analyzing tokenization")
392
- parser.add_argument("--load_tokenizer", help="Load existing tokenizer from directory")
393
- args = parser.parse_args()
394
- default_input_file = "/kaggle/input/gpt-based-slm-dataset/slm_training_complete.jsonl"
395
- default_text_file = "/kaggle/working/text_data/training_data_chat.txt"
396
- if not args.input_files and not args.load_tokenizer:
397
- if os.path.exists(default_input_file):
398
- args.input_files = [default_input_file]
399
- print(f"No arguments provided, using default input file: {default_input_file}")
400
- elif os.path.exists(default_text_file):
401
- args.input_files = [default_text_file]
402
- print(f"No arguments provided, using default text file: {default_text_file}")
403
  else:
404
- parser.error("No input files or tokenizer directory provided, and default files not found. "
405
- "Please specify --input_files or --load_tokenizer.")
406
- if args.load_tokenizer:
407
- tokenizer = TechnicalTokenizer()
408
- tokenizer.load(args.load_tokenizer)
409
- if args.test_file:
410
- print(f"\nTesting on {args.test_file}")
411
- dataset = ConversationDataset(args.test_file, tokenizer)
412
- tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
413
- print(f"Dataset statistics:")
414
- print(f" Total conversations: {len(tokenized)}")
415
- print(f" Total tokens: {stats['total_tokens']:,}")
416
- print(f" Average tokens per conversation: {stats['avg_length']:.1f}")
417
- print(f" Conversations truncated: {stats['truncated']}")
418
- else:
419
- tokenizer = train_tokenizer_from_files(
420
- file_paths=args.input_files,
421
- vocab_size=args.vocab_size,
422
- min_freq=args.min_freq,
423
- output_dir=args.output_dir,
424
- max_texts=args.max_texts
425
- )
426
- if args.test_file:
427
- print(f"\nTesting on {args.test_file}")
428
- dataset = ConversationDataset(args.test_file, tokenizer)
429
- tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
430
- print(f"Dataset statistics:")
431
- print(f" Total conversations: {len(tokenized)}")
432
- print(f" Total tokens: {stats['total_tokens']:,}")
433
- print(f" Average tokens per conversation: {stats['avg_length']:.1f}")
434
- print(f" Conversations truncated: {stats['truncated']}")
435
 
436
  if __name__ == "__main__":
437
- main()
 
2
  import json
3
  import pickle
4
  import argparse
5
+ import logging
6
+ import threading
7
+ from collections import Counter, defaultdict, OrderedDict
8
+ from typing import List, Dict, Set, Optional, Tuple, Union, Iterator, Any
9
+ from dataclasses import dataclass, asdict
10
+ from pathlib import Path
11
  import re
12
  import unicodedata
13
+ import heapq
14
+ from functools import lru_cache
15
+ import time
16
+ from contextlib import contextmanager
17
+
18
+ # Configure logging
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+ @dataclass
26
+ class TokenizerConfig:
27
+ """Configuration class with validation and serialization support"""
28
+
29
+ vocab_size: int = 32000
30
+ min_freq: int = 2
31
+ max_token_length: int = 256
32
+ cache_size: int = 10000
33
+ chunk_size: int = 10000
34
+
35
+ # Special tokens
36
+ pad_token: str = '<pad>'
37
+ unk_token: str = '<unk>'
38
+ bos_token: str = '<bos>'
39
+ eos_token: str = '<eos>'
40
+
41
+ # Technical domain specific
42
+ enable_code_detection: bool = True
43
+ enable_math_detection: bool = True
44
+ enable_url_detection: bool = True
45
+
46
+ def __post_init__(self):
47
+ """Validate configuration parameters"""
48
+ if self.vocab_size <= 0:
49
+ raise ValueError(f"vocab_size must be positive, got {self.vocab_size}")
50
+ if self.min_freq <= 0:
51
+ raise ValueError(f"min_freq must be positive, got {self.min_freq}")
52
+ if self.max_token_length <= 0:
53
+ raise ValueError(f"max_token_length must be positive, got {self.max_token_length}")
54
+ if self.cache_size <= 0:
55
+ raise ValueError(f"cache_size must be positive, got {self.cache_size}")
56
+
57
+ logger.info(f"TokenizerConfig validated: vocab_size={self.vocab_size}")
58
+
59
+ def save(self, path: Union[str, Path]) -> None:
60
+ """Save configuration to JSON file"""
61
+ path = Path(path)
62
+ with open(path, 'w', encoding='utf-8') as f:
63
+ json.dump(asdict(self), f, indent=2, ensure_ascii=False)
64
+ logger.info(f"Config saved to {path}")
65
+
66
+ @classmethod
67
+ def load(cls, path: Union[str, Path]) -> 'TokenizerConfig':
68
+ """Load configuration from JSON file"""
69
+ path = Path(path)
70
+ if not path.exists():
71
+ raise FileNotFoundError(f"Config file not found: {path}")
72
+
73
+ with open(path, 'r', encoding='utf-8') as f:
74
+ config_dict = json.load(f)
75
+
76
+ logger.info(f"Config loaded from {path}")
77
+ return cls(**config_dict)
78
+
79
+
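For reference, a minimal sketch of the new TokenizerConfig round trip. It assumes this file is importable as custom_tokenizer; the temporary path is illustrative only.

```python
# Hypothetical usage sketch for TokenizerConfig (not part of this commit).
from pathlib import Path
from tempfile import TemporaryDirectory

from custom_tokenizer import TokenizerConfig  # assumed module name

config = TokenizerConfig(vocab_size=16000, min_freq=3)  # __post_init__ validates the values

with TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    config.save(path)                      # serializes asdict(config) as JSON
    restored = TokenizerConfig.load(path)  # raises FileNotFoundError if the file is missing
    assert restored == config              # dataclass equality compares every field
```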
80
+ class ThreadSafeLRUCache:
81
+ """Thread-safe LRU cache with size limits"""
82
+
83
+ def __init__(self, max_size: int = 10000):
84
+ self.max_size = max_size
85
+ self.cache = OrderedDict()
86
+ self.lock = threading.RLock()
87
+
88
+ def get(self, key: str) -> Optional[List[str]]:
89
+ """Get value from cache"""
90
+ with self.lock:
91
+ if key in self.cache:
92
+ # Move to end (most recently used)
93
+ value = self.cache.pop(key)
94
+ self.cache[key] = value
95
+ return value
96
+ return None
97
+
98
+ def put(self, key: str, value: List[str]) -> None:
99
+ """Add value to cache"""
100
+ with self.lock:
101
+ if key in self.cache:
102
+ self.cache.pop(key)
103
+ elif len(self.cache) >= self.max_size:
104
+ # Remove least recently used item
105
+ self.cache.popitem(last=False)
106
+
107
+ self.cache[key] = value
108
+
109
+ def clear(self) -> None:
110
+ """Clear all cache entries"""
111
+ with self.lock:
112
+ self.cache.clear()
113
+
114
+ def size(self) -> int:
115
+ """Get current cache size"""
116
+ with self.lock:
117
+ return len(self.cache)
118
+
119
+
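A short sketch of the cache's LRU eviction, which is what cache_size ultimately controls (illustrative values only).

```python
# Hypothetical sketch of ThreadSafeLRUCache eviction (not part of this commit).
from custom_tokenizer import ThreadSafeLRUCache  # assumed module name

cache = ThreadSafeLRUCache(max_size=2)
cache.put("foo", ["f", "oo"])
cache.put("bar", ["b", "ar"])
cache.get("foo")               # touching "foo" makes "bar" the least recently used entry
cache.put("baz", ["b", "az"])  # at capacity, so "bar" is evicted
assert cache.get("bar") is None
assert cache.size() == 2
```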
120
+ class EfficientBPE:
121
+ """Efficient BPE implementation using priority queues"""
122
+
123
+ def __init__(self):
124
+ self.merges: List[Tuple[str, str]] = []
125
+ self.merge_ranks: Dict[Tuple[str, str], int] = {}
126
+
127
+ def train(self, word_counts: Dict[str, int], num_merges: int) -> None:
128
+ """Train BPE using efficient algorithm with priority queue"""
129
+ logger.info(f"Training BPE with {num_merges} merges")
130
+
131
+ # Convert words to character sequences
132
+ vocab = defaultdict(int)
133
+ for word, count in word_counts.items():
134
+ vocab[tuple(word)] += count
135
+
136
+ # Get all possible pairs and their frequencies
137
+ def get_pairs(vocab_dict):
138
+ pairs = defaultdict(int)
139
+ for word, freq in vocab_dict.items():
140
+ if len(word) < 2:
141
+ continue
142
+ for i in range(len(word) - 1):
143
+ pair = (word[i], word[i + 1])
144
+ pairs[pair] += freq
145
+ return pairs
146
+
147
+ for i in range(num_merges):
148
+ if i % 1000 == 0:
149
+ logger.info(f"BPE merge progress: {i}/{num_merges}")
150
+
151
+ pairs = get_pairs(vocab)
152
+ if not pairs:
153
+ logger.warning(f"No more pairs available at merge {i}")
154
+ break
155
+
156
+ # Get most frequent pair
157
+ best_pair = max(pairs.items(), key=lambda x: x[1])[0]
158
+
159
+ # Merge the best pair
160
+ new_vocab = {}
161
+ bigram = best_pair
162
+
163
+ for word, freq in vocab.items():
164
+ new_word = []
165
+ i = 0
166
+ while i < len(word):
167
+ if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
168
+ new_word.append(word[i] + word[i + 1])
169
+ i += 2
170
+ else:
171
+ new_word.append(word[i])
172
+ i += 1
173
+ new_vocab[tuple(new_word)] = freq
174
+
175
+ vocab = new_vocab
176
+ self.merges.append(best_pair)
177
+ self.merge_ranks[best_pair] = len(self.merges) - 1
178
+
179
+ logger.info(f"BPE training completed with {len(self.merges)} merges")
180
+
181
+ def apply(self, word: str) -> List[str]:
182
+ """Apply BPE merges to a word efficiently"""
183
+ if len(word) <= 1:
184
+ return list(word)
185
+
186
+ # Start with character-level tokens
187
+ word_tokens = list(word)
188
+
189
+ # Apply merges in order
190
+ for merge_pair in self.merges:
191
+ if len(word_tokens) == 1:
192
+ break
193
+
194
+ new_tokens = []
195
+ i = 0
196
+ while i < len(word_tokens):
197
+ if (i < len(word_tokens) - 1 and
198
+ word_tokens[i] == merge_pair[0] and
199
+ word_tokens[i + 1] == merge_pair[1]):
200
+ new_tokens.append(merge_pair[0] + merge_pair[1])
201
+ i += 2
202
+ else:
203
+ new_tokens.append(word_tokens[i])
204
+ i += 1
205
+
206
+ word_tokens = new_tokens
207
+
208
+ return word_tokens
209
+
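To see what EfficientBPE learns, a toy run on a hand-made word-count dict (counts and num_merges are arbitrary; the resulting merge list depends on pair frequencies in the input).

```python
# Hypothetical toy run of EfficientBPE (not part of this commit).
from custom_tokenizer import EfficientBPE  # assumed module name

bpe = EfficientBPE()
# word -> frequency; each word is split into characters before merging
bpe.train({"lower": 5, "lowest": 3, "newer": 2}, num_merges=10)

print(bpe.merges)          # learned (left, right) symbol pairs in merge order
print(bpe.apply("lower"))  # the character sequence collapsed by the learned merges
```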
210
+
211
  class TechnicalTokenizer:
212
  """
213
+ Production-quality tokenizer for technical content with:
214
+ - Efficient BPE implementation
215
+ - Thread-safe caching
216
+ - Memory-efficient streaming
217
+ - Comprehensive error handling
218
+ - Proper logging and monitoring
219
+ """
220
+
221
+ def __init__(self, config: Optional[TokenizerConfig] = None):
222
+ self.config = config or TokenizerConfig()
223
+
224
+ # Core components
225
+ self.vocab: Dict[str, int] = {}
226
+ self.id_to_token: Dict[int, str] = {}
227
+ self.token_frequencies: Counter = Counter()
228
+ self.bpe = EfficientBPE()
229
+
230
+ # Thread-safe cache
231
+ self.cache = ThreadSafeLRUCache(self.config.cache_size)
232
+
233
+ # Special tokens mapping
234
  self.special_tokens = {
235
+ self.config.pad_token: 0,
236
+ self.config.unk_token: 1,
237
+ self.config.bos_token: 2,
238
+ self.config.eos_token: 3,
239
  '<system>': 4,
240
  '<user>': 5,
241
  '<assistant>': 6,
 
245
  '<|code|>': 10,
246
  '<|/code|>': 11,
247
  '<|math|>': 12,
248
+ '<|/math|>': 13,
249
+ '<URL>': 14,
250
+ '<EMAIL>': 15,
251
+ '<NUMBER>': 16
252
+ }
253
+
254
+ # Initialize vocabulary with special tokens
255
+ self._initialize_vocab()
256
+
257
+ # Compile regex patterns for efficiency
258
+ self._compile_patterns()
259
+
260
+ # Technical terms for priority processing
261
+ self.technical_terms = self._load_technical_terms()
262
+
263
+ logger.info(f"TechnicalTokenizer initialized with vocab_size={self.config.vocab_size}")
264
+
265
+ def _initialize_vocab(self) -> None:
266
+ """Initialize vocabulary with special tokens"""
267
  self.vocab = self.special_tokens.copy()
268
+ self.id_to_token = {v: k for k, v in self.special_tokens.items()}
269
+
270
+ def _compile_patterns(self) -> None:
271
+ """Compile regex patterns for efficient text processing"""
272
+ patterns = []
273
+
274
+ if self.config.enable_code_detection:
275
+ patterns.extend([
276
+ r'```[\s\S]*?```', # Code blocks
277
+ r'`[^`\n]+`', # Inline code
278
+ ])
279
+
280
+ if self.config.enable_url_detection:
281
+ patterns.append(r'https?://[^\s<>"{}|\\^`[\]]+')
282
+
283
+ patterns.extend([
284
+ r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b', # Email
285
+ r'<[^>]+>', # Special tokens
286
+ r'\b\d+\.?\d*\b', # Numbers
287
+ r'\b\w+(?:\'\w+)?\b', # Words with contractions
288
+ r'[^\w\s]', # Punctuation
289
+ ])
290
+
291
+ self.tokenizer_pattern = re.compile('|'.join(f'({pattern})' for pattern in patterns))
292
+
293
+ # Additional patterns for normalization
294
+ self.newline_pattern = re.compile(r'\r\n|\r')
295
+ self.tab_pattern = re.compile(r'\t')
296
+ self.multiple_space_pattern = re.compile(r'\s+')
297
+
298
+ def _load_technical_terms(self) -> Set[str]:
299
+ """Load technical terms for priority processing"""
300
+ return {
301
+ # Programming
302
+ 'function', 'variable', 'array', 'object', 'class', 'method',
303
+ 'parameter', 'return', 'import', 'export', 'async', 'await',
304
+ 'promise', 'callback', 'algorithm', 'datatype', 'boolean',
305
+
306
+ # Languages
307
+ 'python', 'javascript', 'java', 'cpp', 'rust', 'go',
308
+ 'html', 'css', 'sql', 'typescript', 'kotlin', 'swift',
309
+
310
+ # Web/API
311
+ 'api', 'rest', 'graphql', 'json', 'xml', 'http', 'https',
312
+ 'endpoint', 'request', 'response', 'authentication',
313
+
314
+ # Math/ML
315
+ 'neural', 'network', 'model', 'training', 'validation',
316
+ 'accuracy', 'precision', 'recall', 'loss', 'gradient',
317
+ 'derivative', 'integral', 'matrix', 'vector', 'tensor',
318
+ 'transformer', 'attention', 'embedding', 'tokenization',
319
+
320
+ # Infrastructure
321
+ 'docker', 'kubernetes', 'microservice', 'database',
322
+ 'server', 'client', 'deployment', 'scalability'
323
+ }
324
+
325
+ @contextmanager
326
+ def _error_context(self, operation: str):
327
+ """Context manager for consistent error handling"""
328
+ try:
329
+ yield
330
+ except Exception as e:
331
+ logger.error(f"Error in {operation}: {str(e)}")
332
+ raise
333
+
334
  def normalize_text(self, text: str) -> str:
335
+ """Normalize text with proper error handling"""
336
+ if not isinstance(text, str):
337
+ raise TypeError(f"Expected str, got {type(text)}")
338
+
339
+ with self._error_context("text normalization"):
340
+ # Basic normalization
341
+ text = self.newline_pattern.sub('\n', text)
342
+ text = self.tab_pattern.sub('<|tab|>', text)
343
+ text = unicodedata.normalize('NFKC', text)
344
+
345
+ # Handle special token markers
346
+ text = re.sub(r'<\|system\|>', ' <system> ', text)
347
+ text = re.sub(r'<\|user\|>', ' <user> ', text)
348
+ text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
349
+ text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)
350
+
351
+ return text.strip()
352
+
353
  def pre_tokenize(self, text: str) -> List[str]:
354
+ """Pre-tokenize text into words and special tokens"""
355
+ if not text:
356
+ return []
357
+
358
+ with self._error_context("pre-tokenization"):
359
+ normalized_text = self.normalize_text(text)
360
+
361
+ # Find all tokens using compiled pattern
362
+ matches = self.tokenizer_pattern.findall(normalized_text)
363
+
364
+ # Flatten the match groups and filter empty strings
365
+ tokens = []
366
+ for match_groups in matches:
367
+ for group in match_groups:
368
+ if group:
369
+ tokens.append(group)
370
+ break
371
+
372
+ return [token.strip() for token in tokens if token.strip()]
373
+
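A quick way to inspect what the compiled pattern extracts before any BPE is applied; the output depends on which detection flags are enabled in the config.

```python
# Hypothetical inspection of pre_tokenize output (not part of this commit).
from custom_tokenizer import TechnicalTokenizer, TokenizerConfig  # assumed module name

tok = TechnicalTokenizer(TokenizerConfig())  # pre-tokenization works without training
sample = "Call the REST API at https://example.com/v1 and check `response.json()`."
for piece in tok.pre_tokenize(sample):
    print(repr(piece))  # URLs, inline code, words, and punctuation come out as separate pieces
```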
374
+ def train_from_iterator(self, text_iterator: Iterator[str],
375
+ total_texts: Optional[int] = None) -> None:
376
+ """
377
+ Train tokenizer from text iterator for memory efficiency
378
+
379
+ Args:
380
+ text_iterator: Iterator yielding text strings
381
+ total_texts: Optional total count for progress tracking
382
+ """
383
+ logger.info("Starting BPE training from iterator")
384
+ start_time = time.time()
385
+
386
+ word_counts = Counter()
387
+ processed_texts = 0
388
+
389
+ # Process texts in chunks to manage memory
390
+ current_chunk = []
391
+
392
+ for text in text_iterator:
393
+ current_chunk.append(text)
394
+ processed_texts += 1
395
+
396
+ if len(current_chunk) >= self.config.chunk_size:
397
+ self._process_text_chunk(current_chunk, word_counts)
398
+ current_chunk.clear()
399
+
400
+ if processed_texts % 10000 == 0:
401
+ elapsed = time.time() - start_time
402
+ logger.info(f"Processed {processed_texts} texts in {elapsed:.1f}s")
403
+
404
+ # Process remaining texts
405
+ if current_chunk:
406
+ self._process_text_chunk(current_chunk, word_counts)
407
+
408
+ logger.info(f"Pre-processing completed: {len(word_counts)} unique words")
409
+
410
+ # Filter by frequency and boost technical terms
411
+ filtered_words = {}
412
+ for word, count in word_counts.items():
413
+ if count >= self.config.min_freq:
414
+ # Boost technical terms
415
+ if word.lower() in self.technical_terms:
416
+ count *= 5
417
+ filtered_words[word] = count
418
+
419
+ logger.info(f"After filtering: {len(filtered_words)} words")
420
+
421
+ # Build character vocabulary
422
  all_chars = set()
423
+ for word in filtered_words:
424
+ all_chars.update(word)
425
+
426
+ # Add characters to vocabulary
427
  for char in sorted(all_chars):
428
  if char not in self.vocab:
429
+ token_id = len(self.vocab)
430
+ self.vocab[char] = token_id
431
+ self.id_to_token[token_id] = char
432
+
433
+ # Calculate number of merges needed
434
+ current_vocab_size = len(self.vocab)
435
+ target_vocab_size = self.config.vocab_size
436
+ num_merges = target_vocab_size - current_vocab_size
437
+
438
+ if num_merges > 0:
439
+ # Train BPE
440
+ self.bpe.train(filtered_words, num_merges)
441
+
442
+ # Add merged tokens to vocabulary
443
+ for merge_pair in self.bpe.merges:
444
+ merged_token = merge_pair[0] + merge_pair[1]
445
+ if merged_token not in self.vocab:
446
+ token_id = len(self.vocab)
447
+ self.vocab[merged_token] = token_id
448
+ self.id_to_token[token_id] = merged_token
449
+
450
+ # Update token frequencies
451
+ for word, count in filtered_words.items():
452
+ tokens = self.apply_bpe(word)
453
+ for token in tokens:
454
+ self.token_frequencies[token] += count
455
+
456
+ training_time = time.time() - start_time
457
+ logger.info(f"Training completed in {training_time:.1f}s")
458
+ logger.info(f"Final vocabulary size: {len(self.vocab)}")
459
+
460
+ def _process_text_chunk(self, texts: List[str], word_counts: Counter) -> None:
461
+ """Process a chunk of texts and update word counts"""
462
+ for text in texts:
463
+ try:
464
+ tokens = self.pre_tokenize(text)
465
+ for token in tokens:
466
+ if len(token) <= self.config.max_token_length:
467
+ word_counts[token] += 1
468
+ except Exception as e:
469
+ logger.warning(f"Error processing text chunk: {e}")
470
+ continue
471
+
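Because train_from_iterator consumes an iterator and buffers work through _process_text_chunk, it can be fed from a generator without materialising the whole corpus. A minimal in-memory sketch (corpus, sizes, and chunk_size are illustrative):

```python
# Hypothetical in-memory training sketch (not part of this commit).
from custom_tokenizer import TechnicalTokenizer, TokenizerConfig  # assumed module name

corpus = [
    "def add(a, b): return a + b",
    "The API returns JSON over HTTP.",
] * 100  # repeat a tiny corpus so tokens clear min_freq

config = TokenizerConfig(vocab_size=500, min_freq=2, chunk_size=50)
tok = TechnicalTokenizer(config)
tok.train_from_iterator(iter(corpus), total_texts=len(corpus))
print(tok.get_vocab_size())  # training stops early with a warning if no pairs are left to merge
```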
472
  def apply_bpe(self, word: str) -> List[str]:
473
+ """Apply BPE to a word with caching"""
474
+ if not word:
475
+ return []
476
+
477
+ # Check cache first
478
+ cached_result = self.cache.get(word)
479
+ if cached_result is not None:
480
+ return cached_result
481
+
482
+ # Apply BPE
483
+ tokens = self.bpe.apply(word)
484
+
485
+ # Cache the result
486
+ self.cache.put(word, tokens)
487
+
488
+ return tokens
489
+
490
  def tokenize(self, text: str) -> List[str]:
491
+ """Tokenize text into subword tokens"""
492
+ if not text:
493
+ return []
494
+
495
+ with self._error_context("tokenization"):
496
+ pre_tokens = self.pre_tokenize(text)
497
+ final_tokens = []
498
+
499
+ for token in pre_tokens:
500
+ if token in self.special_tokens or token in self.vocab:
501
+ final_tokens.append(token)
502
+ else:
503
+ bpe_tokens = self.apply_bpe(token)
504
+ final_tokens.extend(bpe_tokens)
505
+
506
+ return final_tokens
507
+
508
+ def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
509
+ """Encode text to token IDs"""
510
+ if not isinstance(text, str):
511
+ raise TypeError(f"Expected str, got {type(text)}")
512
+
513
  tokens = self.tokenize(text)
514
+
515
  if add_special_tokens:
516
+ tokens = [self.config.bos_token] + tokens + [self.config.eos_token]
517
+
518
  ids = []
519
+ unk_id = self.vocab[self.config.unk_token]
520
+
521
  for token in tokens:
522
+ token_id = self.vocab.get(token, unk_id)
523
+ ids.append(token_id)
524
+
525
+ return ids
526
+
527
+ def decode(self, ids: List[int], skip_special_tokens: bool = False) -> str:
528
+ """Decode token IDs to text"""
529
+ if not isinstance(ids, (list, tuple)):
530
+ raise TypeError(f"Expected list or tuple, got {type(ids)}")
531
+
532
  tokens = []
533
+ for token_id in ids:
534
+ if not isinstance(token_id, int):
535
+ raise TypeError(f"Expected int token ID, got {type(token_id)}")
536
+
537
+ if token_id not in self.id_to_token:
538
+ logger.warning(f"Unknown token ID: {token_id}")
539
+ continue
540
+
541
+ token = self.id_to_token[token_id]
542
+
543
  if skip_special_tokens and token in self.special_tokens:
544
  continue
545
+
546
+ tokens.append(token)
547
+
548
+ # Join tokens and clean up
549
  text = ''.join(tokens)
550
  text = text.replace('<|tab|>', '\t')
551
  text = text.replace('<|newline|>', '\n')
552
+
553
+ return text
554
+
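After training, encode/decode behave as sketched below; note that decode joins tokens with no separator, so original whitespace is not reconstructed. A trained tokenizer `tok` (e.g. from the sketch above) is assumed.

```python
# Hypothetical encode/decode round trip (assumes a trained tokenizer `tok`).
ids = tok.encode("def add(a, b): return a + b", add_special_tokens=True)
text = tok.decode(ids, skip_special_tokens=True)
print(ids[:5])  # begins with the <bos> id when add_special_tokens=True
print(text)     # tokens are concatenated directly, so spaces between words are lost
```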
555
  def get_vocab_size(self) -> int:
556
+ """Get vocabulary size"""
557
+ return len(self.vocab)
558
+
559
+ def get_vocab(self) -> Dict[str, int]:
560
+ """Get vocabulary dictionary (copy for safety)"""
561
+ return self.vocab.copy()
562
+
563
+ def get_cache_info(self) -> Dict[str, int]:
564
+ """Get cache statistics"""
565
+ return {
566
+ 'size': self.cache.size(),
567
+ 'max_size': self.config.cache_size,
568
+ 'hit_rate': getattr(self.cache, 'hit_rate', 0)
569
+ }
570
+
571
+ def save(self, save_dir: Union[str, Path]) -> None:
572
+ """Save tokenizer with validation"""
573
+ save_dir = Path(save_dir)
574
+ save_dir.mkdir(parents=True, exist_ok=True)
575
+
576
+ logger.info(f"Saving tokenizer to {save_dir}")
577
+
578
+ try:
579
+ # Save configuration
580
+ self.config.save(save_dir / 'config.json')
581
+
582
+ # Save vocabulary
583
+ with open(save_dir / 'vocab.json', 'w', encoding='utf-8') as f:
584
+ json.dump(self.vocab, f, indent=2, ensure_ascii=False)
585
+
586
+ # Save BPE merges
587
+ with open(save_dir / 'merges.txt', 'w', encoding='utf-8') as f:
588
+ for merge in self.bpe.merges:
589
+ f.write(f"{merge[0]} {merge[1]}\n")
590
+
591
+ # Save token frequencies
592
+ with open(save_dir / 'frequencies.pkl', 'wb') as f:
593
+ pickle.dump(dict(self.token_frequencies), f)
594
+
595
+ # Save metadata
596
+ metadata = {
597
+ 'version': '2.0',
598
+ 'vocab_size': len(self.vocab),
599
+ 'num_merges': len(self.bpe.merges),
600
+ 'special_tokens': self.special_tokens
601
+ }
602
+
603
+ with open(save_dir / 'metadata.json', 'w', encoding='utf-8') as f:
604
+ json.dump(metadata, f, indent=2)
605
+
606
+ logger.info("Tokenizer saved successfully")
607
+
608
+ except Exception as e:
609
+ logger.error(f"Error saving tokenizer: {e}")
610
+ raise
611
+
612
+ @classmethod
613
+ def load(cls, save_dir: Union[str, Path]) -> 'TechnicalTokenizer':
614
+ """Load tokenizer from directory"""
615
+ save_dir = Path(save_dir)
616
+
617
+ if not save_dir.exists():
618
+ raise FileNotFoundError(f"Tokenizer directory not found: {save_dir}")
619
+
620
+ logger.info(f"Loading tokenizer from {save_dir}")
621
+
622
+ try:
623
+ # Load configuration
624
+ config = TokenizerConfig.load(save_dir / 'config.json')
625
+
626
+ # Create tokenizer instance
627
+ tokenizer = cls(config)
628
+
629
+ # Load vocabulary
630
+ with open(save_dir / 'vocab.json', 'r', encoding='utf-8') as f:
631
+ tokenizer.vocab = json.load(f)
632
+
633
+ tokenizer.id_to_token = {v: k for k, v in tokenizer.vocab.items()}
634
+
635
+ # Load BPE merges
636
+ merges_file = save_dir / 'merges.txt'
637
+ if merges_file.exists():
638
+ with open(merges_file, 'r', encoding='utf-8') as f:
639
+ for line in f:
640
+ line = line.strip()
641
+ if line:
642
+ parts = line.split()
643
+ if len(parts) == 2:
644
+ tokenizer.bpe.merges.append(tuple(parts))
645
+
646
+ # Rebuild merge ranks
647
+ tokenizer.bpe.merge_ranks = {
648
+ merge: i for i, merge in enumerate(tokenizer.bpe.merges)
649
+ }
650
+
651
+ # Load token frequencies
652
+ freq_file = save_dir / 'frequencies.pkl'
653
+ if freq_file.exists():
654
+ with open(freq_file, 'rb') as f:
655
+ freq_dict = pickle.load(f)
656
+ tokenizer.token_frequencies = Counter(freq_dict)
657
+
658
+ logger.info(f"Tokenizer loaded successfully")
659
+ logger.info(f"Vocabulary size: {len(tokenizer.vocab)}")
660
+ logger.info(f"Number of BPE merges: {len(tokenizer.bpe.merges)}")
661
+
662
+ return tokenizer
663
+
664
+ except Exception as e:
665
+ logger.error(f"Error loading tokenizer: {e}")
666
+ raise
667
+
668
+
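The load entry point is now a classmethod, so the call pattern differs from the old instance-method load. A sketch of the save/load round trip (the directory name is a placeholder; a trained tokenizer `tok` is assumed):

```python
# Hypothetical save/load round trip (assumes a trained tokenizer `tok`).
from custom_tokenizer import TechnicalTokenizer  # assumed module name

tok.save("tokenizer_output")  # writes config.json, vocab.json, merges.txt, frequencies.pkl, metadata.json
restored = TechnicalTokenizer.load("tokenizer_output")
assert restored.get_vocab_size() == tok.get_vocab_size()
```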
669
+ def create_text_iterator(file_paths: List[Union[str, Path]],
670
+ max_texts: Optional[int] = None) -> Iterator[str]:
671
+ """Create memory-efficient text iterator from multiple files"""
672
+ processed_count = 0
673
+
674
+ for file_path in file_paths:
675
+ file_path = Path(file_path)
676
+
677
+ if not file_path.exists():
678
+ logger.warning(f"File not found: {file_path}")
679
+ continue
680
+
681
+ logger.info(f"Processing file: {file_path}")
682
+
683
+ try:
684
+ if file_path.suffix == '.jsonl':
685
+ with open(file_path, 'r', encoding='utf-8') as f:
686
+ for line_num, line in enumerate(f, 1):
687
+ try:
688
+ data = json.loads(line.strip())
689
+
690
+ if 'messages' in data:
691
+ # Conversation format
692
+ texts = []
693
+ for msg in data['messages']:
694
+ content = msg.get('content', '').strip()
695
+ if content:
696
+ texts.append(content)
697
+ if texts:
698
+ yield ' '.join(texts)
699
+ processed_count += 1
700
+
701
+ elif 'text' in data:
702
+ # Simple text format
703
+ text = data['text'].strip()
704
+ if text:
705
+ yield text
706
+ processed_count += 1
707
+
708
+ if max_texts and processed_count >= max_texts:
709
+ return
710
+
711
+ except json.JSONDecodeError as e:
712
+ logger.warning(f"JSON decode error at line {line_num} in {file_path}: {e}")
713
  continue
714
+
715
  else:
716
+ # Plain text file
717
+ with open(file_path, 'r', encoding='utf-8') as f:
718
+ content = f.read()
719
+
720
+ # Split by double newlines or other separators
721
+ chunks = re.split(r'\n\s*\n', content)
722
+
723
+ for chunk in chunks:
724
+ chunk = chunk.strip()
725
+ if chunk and len(chunk) > 50: # Skip very short chunks
726
+ yield chunk
727
+ processed_count += 1
728
+
729
+ if max_texts and processed_count >= max_texts:
730
+ return
731
+
732
+ except Exception as e:
733
+ logger.error(f"Error processing file {file_path}: {e}")
734
+ continue
735
+
736
+ logger.info(f"Total texts processed: {processed_count}")
737
+
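The iterator can also be driven directly, for example to sanity-check how many texts a .jsonl file yields before committing to a full training run (the path is a placeholder):

```python
# Hypothetical sketch of create_text_iterator (not part of this commit).
from custom_tokenizer import create_text_iterator  # assumed module name

count = 0
for text in create_text_iterator(["data/conversations.jsonl"], max_texts=1000):
    count += 1
print(f"yielded {count} texts")  # missing files are skipped with a warning rather than raising
```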
738
+
739
+ def train_tokenizer(input_files: List[Union[str, Path]],
740
+ output_dir: Union[str, Path],
741
+ config: Optional[TokenizerConfig] = None,
742
+ max_texts: Optional[int] = None) -> TechnicalTokenizer:
743
+ """Train a new tokenizer from input files"""
744
+
745
+ config = config or TokenizerConfig()
746
+ tokenizer = TechnicalTokenizer(config)
747
+
748
+ # Create text iterator
749
+ text_iter = create_text_iterator(input_files, max_texts)
750
+
751
+ # Train tokenizer
752
+ tokenizer.train_from_iterator(text_iter)
753
+
754
+ # Save tokenizer
755
+ tokenizer.save(output_dir)
756
+
757
  return tokenizer
758
+
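End to end, the convenience wrapper reduces to one call; a sketch with placeholder paths and a small vocabulary:

```python
# Hypothetical end-to-end training sketch (paths and sizes are placeholders).
from custom_tokenizer import TokenizerConfig, train_tokenizer  # assumed module name

tok = train_tokenizer(
    input_files=["data/conversations.jsonl"],
    output_dir="tokenizer_output",
    config=TokenizerConfig(vocab_size=8000, min_freq=2),
    max_texts=50_000,
)
print(tok.get_vocab_size())
```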
759
+
760
  def main():
761
+ """Main CLI interface"""
762
+ parser = argparse.ArgumentParser(
763
+ description="Production-Quality Technical Tokenizer",
764
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
765
+ )
766
+
767
+ # Input/Output
768
+ parser.add_argument('--input_files', nargs='+',
769
+ help='Input files for training')
770
+ parser.add_argument('--output_dir', default='tokenizer_output',
771
+ help='Output directory for tokenizer')
772
+ parser.add_argument('--load_from',
773
+ help='Load existing tokenizer from directory')
774
+
775
+ # Training parameters
776
+ parser.add_argument('--vocab_size', type=int, default=32000,
777
+ help='Target vocabulary size')
778
+ parser.add_argument('--min_freq', type=int, default=2,
779
+ help='Minimum token frequency')
780
+ parser.add_argument('--max_texts', type=int,
781
+ help='Maximum number of texts to process')
782
+ parser.add_argument('--cache_size', type=int, default=10000,
783
+ help='BPE cache size')
784
+
785
+ # Testing
786
+ parser.add_argument('--test_text',
787
+ help='Test text for tokenization analysis')
788
+ parser.add_argument('--benchmark', action='store_true',
789
+ help='Run performance benchmarks')
790
+
791
+ # Logging
792
+ parser.add_argument('--verbose', action='store_true',
793
+ help='Enable verbose logging')
794
+
795
+ args = parser.parse_args()
796
+
797
+ if args.verbose:
798
+ logging.getLogger().setLevel(logging.DEBUG)
799
+
800
+ try:
801
+ if args.load_from:
802
+ # Load existing tokenizer
803
+ tokenizer = TechnicalTokenizer.load(args.load_from)
804
+
805
+ if args.test_text:
806
+ print(f"\nTokenization Analysis:")
807
+ print(f"Text: {args.test_text}")
808
+ tokens = tokenizer.tokenize(args.test_text)
809
+ ids = tokenizer.encode(args.test_text)
810
+ decoded = tokenizer.decode(ids)
811
+ print(f"Tokens: {tokens}")
812
+ print(f"Token IDs: {ids}")
813
+ print(f"Decoded: {decoded}")
814
+ print(f"Token count: {len(tokens)}")
815
+ print(f"Compression ratio: {len(args.test_text.split()) / len(tokens):.2f}")
816
+
817
+ if args.benchmark:
818
+ run_benchmark(tokenizer)
819
+
820
  else:
821
+ # Train new tokenizer
822
+ if not args.input_files:
823
+ parser.error("--input_files required when not loading existing tokenizer")
824
+
825
+ # Create configuration
826
+ config = TokenizerConfig(
827
+ vocab_size=args.vocab_size,
828
+ min_freq=args.min_freq,
829
+ cache_size=args.cache_size
830
+ )
831
+
832
+ # Train tokenizer
833
+ tokenizer = train_tokenizer(
834
+ input_files=args.input_files,
835
+ output_dir=args.output_dir,
836
+ config=config,
837
+ max_texts=args.max_texts
838
+ )
839
+
840
+ # Test on sample texts
841
+ test_texts = [
842
+ "Hello, how can I help you with your Python programming question?",
843
+ "The neural network architecture uses attention mechanisms for better performance.",
844
+ "```python\ndef fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)\n```",
845
+ "The derivative of x² is 2x, and the integral is (x³)/3 + C."
846
+ ]
847
+
848
+ print("\nTokenization Analysis on Sample Texts:")
849
+ print("=" * 50)
850
+
851
+ for i, text in enumerate(test_texts, 1):
852
+ print(f"\nTest {i}:")
853
+ print(f"Text: {text}")
854
+ tokens = tokenizer.tokenize(text)
855
+ ids = tokenizer.encode(text)
856
+ print(f"Tokens ({len(tokens)}): {tokens}")
857
+ print(f"Token IDs: {ids}")
858
+ word_count = len(text.split())
859
+ compression_ratio = word_count / len(tokens) if tokens else 0
860
+ print(f"Compression ratio: {compression_ratio:.2f}")
861
+
862
+ print(f"\nTokenizer training completed!")
863
+ print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
864
+ print(f"Cache info: {tokenizer.get_cache_info()}")
865
+
866
+ except Exception as e:
867
+ logger.error(f"Error in main: {e}")
868
+ if args.verbose:
869
+ import traceback
870
+ traceback.print_exc()
871
+ return 1
872
+
873
+ return 0
874
+
875
+
876
+ def run_benchmark(tokenizer: TechnicalTokenizer) -> None:
877
+ """Run performance benchmarks on the tokenizer"""
878
+ import time
879
+ import random
880
+ import string
881
+
882
+ print("\nRunning Performance Benchmarks...")
883
+ print("=" * 50)
884
+
885
+ # Generate test data
886
+ test_texts = []
887
+
888
+ # Short texts
889
+ for _ in range(1000):
890
+ length = random.randint(10, 50)
891
+ text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
892
+ for _ in range(length))
893
+ test_texts.append(text)
894
+
895
+ # Medium texts
896
+ for _ in range(100):
897
+ length = random.randint(100, 500)
898
+ text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
899
+ for _ in range(length))
900
+ test_texts.append(text)
901
+
902
+ # Long texts
903
+ for _ in range(10):
904
+ length = random.randint(1000, 5000)
905
+ text = ' '.join(''.join(random.choices(string.ascii_lowercase, k=random.randint(3, 10)))
906
+ for _ in range(length))
907
+ test_texts.append(text)
908
+
909
+ # Benchmark tokenization
910
+ print("Benchmarking tokenization...")
911
+ start_time = time.time()
912
+
913
+ total_tokens = 0
914
+ for text in test_texts:
915
+ tokens = tokenizer.tokenize(text)
916
+ total_tokens += len(tokens)
917
+
918
+ tokenization_time = time.time() - start_time
919
+
920
+ # Benchmark encoding
921
+ print("Benchmarking encoding...")
922
+ start_time = time.time()
923
+
924
+ all_ids = []
925
+ for text in test_texts:
926
+ ids = tokenizer.encode(text)
927
+ all_ids.append(ids)
928
+
929
+ encoding_time = time.time() - start_time
930
+
931
+ # Benchmark decoding
932
+ print("Benchmarking decoding...")
933
+ start_time = time.time()
934
+
935
+ for ids in all_ids:
936
+ decoded = tokenizer.decode(ids)
937
+
938
+ decoding_time = time.time() - start_time
939
+
940
+ # Print results
941
+ print(f"\nBenchmark Results:")
942
+ print(f"Texts processed: {len(test_texts)}")
943
+ print(f"Total tokens: {total_tokens:,}")
944
+ print(f"Tokenization time: {tokenization_time:.3f}s")
945
+ print(f"Encoding time: {encoding_time:.3f}s")
946
+ print(f"Decoding time: {decoding_time:.3f}s")
947
+ print(f"Tokenization speed: {total_tokens/tokenization_time:.0f} tokens/sec")
948
+ print(f"Cache info: {tokenizer.get_cache_info()}")
949
+
950
+
951
+ class TokenizerTester:
952
+ """Comprehensive testing utilities for the tokenizer"""
953
+
954
+ def __init__(self, tokenizer: TechnicalTokenizer):
955
+ self.tokenizer = tokenizer
956
+
957
+ def test_roundtrip_consistency(self, texts: List[str]) -> Dict[str, Any]:
958
+ """Test encode/decode roundtrip consistency"""
959
+ results = {
960
+ 'total_tests': len(texts),
961
+ 'passed': 0,
962
+ 'failed': 0,
963
+ 'failures': []
964
+ }
965
+
966
+ for i, text in enumerate(texts):
967
+ try:
968
+ # Encode then decode
969
+ ids = self.tokenizer.encode(text, add_special_tokens=False)
970
+ decoded = self.tokenizer.decode(ids, skip_special_tokens=True)
971
+
972
+ # Check if roundtrip preserves meaning (not exact match due to BPE)
973
+ original_tokens = self.tokenizer.tokenize(text)
974
+ decoded_tokens = self.tokenizer.tokenize(decoded)
975
+
976
+ if len(original_tokens) == len(decoded_tokens):
977
+ results['passed'] += 1
978
+ else:
979
+ results['failed'] += 1
980
+ results['failures'].append({
981
+ 'index': i,
982
+ 'original': text,
983
+ 'decoded': decoded,
984
+ 'original_tokens': len(original_tokens),
985
+ 'decoded_tokens': len(decoded_tokens)
986
+ })
987
+
988
+ except Exception as e:
989
+ results['failed'] += 1
990
+ results['failures'].append({
991
+ 'index': i,
992
+ 'error': str(e),
993
+ 'text': text
994
+ })
995
+
996
+ return results
997
+
998
+ def test_special_tokens(self) -> Dict[str, bool]:
999
+ """Test special token handling"""
1000
+ results = {}
1001
+
1002
+ for token_name, token_id in self.tokenizer.special_tokens.items():
1003
+ try:
1004
+ # Test encoding
1005
+ ids = self.tokenizer.encode(token_name, add_special_tokens=False)
1006
+ expected_id = self.tokenizer.vocab.get(token_name)
1007
+
1008
+ # Test decoding
1009
+ decoded = self.tokenizer.decode([token_id])
1010
+
1011
+ results[token_name] = (
1012
+ expected_id in ids and
1013
+ token_name in decoded
1014
+ )
1015
+
1016
+ except Exception:
1017
+ results[token_name] = False
1018
+
1019
+ return results
1020
+
1021
+ def test_edge_cases(self) -> Dict[str, bool]:
1022
+ """Test edge cases and error conditions"""
1023
+ tests = {
1024
+ 'empty_string': True,
1025
+ 'whitespace_only': True,
1026
+ 'very_long_text': True,
1027
+ 'unicode_text': True,
1028
+ 'special_chars': True
1029
+ }
1030
+
1031
+ try:
1032
+ # Empty string
1033
+ result = self.tokenizer.encode("")
1034
+ tests['empty_string'] = isinstance(result, list)
1035
+ except Exception:
1036
+ tests['empty_string'] = False
1037
+
1038
+ try:
1039
+ # Whitespace only
1040
+ result = self.tokenizer.encode(" \n\t ")
1041
+ tests['whitespace_only'] = isinstance(result, list)
1042
+ except Exception:
1043
+ tests['whitespace_only'] = False
1044
+
1045
+ try:
1046
+ # Very long text
1047
+ long_text = "test " * 10000
1048
+ result = self.tokenizer.encode(long_text)
1049
+ tests['very_long_text'] = isinstance(result, list)
1050
+ except Exception:
1051
+ tests['very_long_text'] = False
1052
+
1053
+ try:
1054
+ # Unicode text
1055
+ unicode_text = "Hello 世界 🌍 café naïve"
1056
+ result = self.tokenizer.encode(unicode_text)
1057
+ tests['unicode_text'] = isinstance(result, list)
1058
+ except Exception:
1059
+ tests['unicode_text'] = False
1060
+
1061
+ try:
1062
+ # Special characters
1063
+ special_text = "!@#$%^&*()_+-=[]{}|;:'\",.<>?/~`"
1064
+ result = self.tokenizer.encode(special_text)
1065
+ tests['special_chars'] = isinstance(result, list)
1066
+ except Exception:
1067
+ tests['special_chars'] = False
1068
+
1069
+ return tests
1070
+
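TokenizerTester is not wired into main(); it can be driven manually as sketched below (a trained tokenizer `tok` is assumed).

```python
# Hypothetical manual run of TokenizerTester (assumes a trained tokenizer `tok`).
from custom_tokenizer import TokenizerTester  # assumed module name

tester = TokenizerTester(tok)
print(tester.test_special_tokens())  # per-token pass/fail dict
print(tester.test_edge_cases())      # empty string, unicode, very long text, ...
report = tester.test_roundtrip_consistency(["Hello world", "def f(x): return x * 2"])
print(report["passed"], "of", report["total_tests"], "round trips preserved token counts")
```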
1071
 
1072
  if __name__ == "__main__":
1073
+ exit(main())