Girinath11 committed on
Commit 613f2bb · verified · 1 Parent(s): 54eced5

Create custom_tokenizer.py

Files changed (1)
1. custom_tokenizer.py +437 -0
custom_tokenizer.py ADDED
@@ -0,0 +1,437 @@
+ import os
+ import json
+ import pickle
+ import argparse
+ from collections import Counter, defaultdict
+ from typing import List, Dict, Set, Optional, Tuple
+ import re
+ import unicodedata
+
+
+ class TechnicalTokenizer:
+     """
+     Custom tokenizer optimized for technical content and conversations.
+     """
+
+     def __init__(self, vocab_size: int = 32000, min_freq: int = 2):
+         self.vocab_size = vocab_size
+         self.min_freq = min_freq
+         self.special_tokens = {
+             '<pad>': 0,
+             '<unk>': 1,
+             '<bos>': 2,
+             '<eos>': 3,
+             '<system>': 4,
+             '<user>': 5,
+             '<assistant>': 6,
+             '<|endoftext|>': 7,
+             '<|newline|>': 8,
+             '<|tab|>': 9,
+             '<|code|>': 10,
+             '<|/code|>': 11,
+             '<|math|>': 12,
+             '<|/math|>': 13
+         }
+         self.vocab = {}
+         self.id_to_token = {}
+         self.token_frequencies = Counter()
+         self.bpe_merges = []
+         self.bpe_cache = {}
+         self.code_pattern = re.compile(r'```[\s\S]*?```|`[^`]+`')
+         self.url_pattern = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
+         self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b')
+         self.number_pattern = re.compile(r'\b\d+\.?\d*\b')
+         self.technical_terms = {
+             'function', 'variable', 'array', 'object', 'class', 'method', 'parameter',
+             'return', 'import', 'export', 'async', 'await', 'promise', 'callback',
+             'algorithm', 'datatype', 'boolean', 'integer', 'string', 'float',
+             'javascript', 'python', 'java', 'cpp', 'html', 'css', 'sql',
+             'api', 'json', 'xml', 'http', 'https', 'rest', 'graphql',
+             'equation', 'formula', 'theorem', 'proof', 'hypothesis',
+             'derivative', 'integral', 'matrix', 'vector', 'polynomial',
+             'probability', 'statistics', 'correlation', 'regression',
+             'neural', 'network', 'model', 'training', 'validation', 'test',
+             'accuracy', 'precision', 'recall', 'f1score', 'loss', 'gradient',
+             'backpropagation', 'forward', 'layer', 'neuron', 'weight', 'bias',
+             'transformer', 'attention', 'embedding', 'tokenization',
+             'database', 'server', 'client', 'protocol', 'encryption', 'security',
+             'authentication', 'authorization', 'deployment', 'docker', 'kubernetes',
+             'microservice', 'architecture', 'scalability', 'performance'
+         }
+         self._init_vocab()
+
+     def _init_vocab(self):
+         self.vocab = self.special_tokens.copy()
+         self.id_to_token = {v: k for k, v in self.special_tokens.items()}
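+
+     # Text normalization: standardize line endings, map tabs to the <|tab|>
+     # special token, apply Unicode NFKC, and shield fenced/inline code spans
+     # behind placeholders so the URL and e-mail masking below cannot touch them.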
+     def normalize_text(self, text: str) -> str:
+         text = re.sub(r'\r\n|\r', '\n', text)
+         text = re.sub(r'\t', '<|tab|>', text)
+         text = unicodedata.normalize('NFKC', text)
+         code_blocks = []
+
+         def replace_code(match):
+             code_blocks.append(match.group())
+             return f'<|code|>CODE_BLOCK_{len(code_blocks)-1}<|/code|>'
+
+         text = self.code_pattern.sub(replace_code, text)
+         text = self.url_pattern.sub('<URL>', text)
+         text = self.email_pattern.sub('<EMAIL>', text)
+         for i, code_block in enumerate(code_blocks):
+             text = text.replace(f'<|code|>CODE_BLOCK_{i}<|/code|>', code_block)
+         return text
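+
+     # Pre-tokenization: chat-role markers (<|system|>, <|user|>, <|assistant|>)
+     # are rewritten to their special-token forms, then one verbose regex splits
+     # the text into special tokens, e-mails, URLs, code spans, numbers, words,
+     # and punctuation, tried in that order.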
+     def pre_tokenize(self, text: str) -> List[str]:
+         text = self.normalize_text(text)
+         text = re.sub(r'<\|system\|>', ' <system> ', text)
+         text = re.sub(r'<\|user\|>', ' <user> ', text)
+         text = re.sub(r'<\|assistant\|>', ' <assistant> ', text)
+         text = re.sub(r'<\|endoftext\|>', ' <|endoftext|> ', text)
+         tokens = re.findall(r'''
+             <[^>]+>|                    # Special tokens
+             \b\w+@\w+\.\w+\b|           # Email-like patterns
+             https?://\S+|               # URLs
+             ```[\s\S]*?```|             # Code blocks
+             `[^`]+`|                    # Inline code
+             \b\d+\.?\d*\b|              # Numbers
+             \b[a-zA-Z]+(?:'[a-z]*)?|    # Words with optional apostrophes
+             [^\w\s]                     # Punctuation
+         ''', text, re.VERBOSE)
+         return [token.strip() for token in tokens if token.strip()]
+
+     def get_pairs(self, word_freqs: Dict[Tuple[str, ...], int]) -> Counter:
+         pairs = Counter()
+         for word, freq in word_freqs.items():
+             if len(word) < 2:
+                 continue
+             for i in range(len(word) - 1):
+                 pair = (word[i], word[i + 1])
+                 pairs[pair] += freq
+         return pairs
+
+     def merge_symbols(self, pair: Tuple[str, str], word_freqs: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
+         new_word_freqs = {}
+         bigram = pair
+         for word, freq in word_freqs.items():
+             new_word = []
+             i = 0
+             while i < len(word):
+                 if i < len(word) - 1 and (word[i], word[i + 1]) == bigram:
+                     new_word.append(word[i] + word[i + 1])
+                     i += 2
+                 else:
+                     new_word.append(word[i])
+                     i += 1
+             new_word_freqs[tuple(new_word)] = freq
+         return new_word_freqs
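+
+     # BPE training: count character-sequence frequencies over the pre-tokenized
+     # corpus, drop sequences below min_freq, up-weight known technical terms,
+     # seed the vocabulary with all single characters, then repeatedly merge the
+     # most frequent adjacent symbol pair until the merge budget is exhausted.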
+     def train_bpe(self, texts: List[str]) -> None:
+         print("Training BPE tokenizer...")
+         word_freqs = Counter()
+         for i, text in enumerate(texts):
+             if i % 10000 == 0:
+                 print(f"Processing text {i}/{len(texts)}")
+             tokens = self.pre_tokenize(text)
+             for token in tokens:
+                 char_seq = tuple(token)
+                 if len(char_seq) > 0:
+                     word_freqs[char_seq] += 1
+         print(f"Found {len(word_freqs)} unique word patterns")
+         word_freqs = {word: freq for word, freq in word_freqs.items() if freq >= self.min_freq}
+         # Boost technical terms; keys are character tuples, e.g. ('a', 'p', 'i')
+         for term in self.technical_terms:
+             key = tuple(term)
+             if key in word_freqs:
+                 word_freqs[key] *= 10
+         all_chars = set()
+         for word in word_freqs:
+             all_chars.update(word)
+         for char in sorted(all_chars):
+             if char not in self.vocab:
+                 self.vocab[char] = len(self.vocab)
+                 self.id_to_token[len(self.id_to_token)] = char
+         target_vocab_size = self.vocab_size - len(self.special_tokens)
+         num_merges = target_vocab_size - len(self.vocab)
+         for i in range(num_merges):
+             if i % 1000 == 0:
+                 print(f"BPE merge {i}/{num_merges}")
+             pairs = self.get_pairs(word_freqs)
+             if not pairs:
+                 break
+             best_pair = pairs.most_common(1)[0][0]
+             word_freqs = self.merge_symbols(best_pair, word_freqs)
+             merged_token = best_pair[0] + best_pair[1]
+             if merged_token not in self.vocab:
+                 self.vocab[merged_token] = len(self.vocab)
+                 self.id_to_token[len(self.id_to_token)] = merged_token
+             self.bpe_merges.append(best_pair)
+         print(f"BPE training complete. Final vocabulary size: {len(self.vocab)}")
+         for word, freq in word_freqs.items():
+             for token in word:
+                 self.token_frequencies[token] += freq
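+
+     # Encoding replays the learned merges, in training order, over a word's
+     # character sequence; results are memoized in bpe_cache.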
+     def apply_bpe(self, word: str) -> List[str]:
+         if word in self.bpe_cache:
+             return self.bpe_cache[word]
+         tokens = list(word)
+         for merge in self.bpe_merges:
+             i = 0
+             while i < len(tokens) - 1:
+                 if tokens[i] == merge[0] and tokens[i + 1] == merge[1]:
+                     tokens = tokens[:i] + [merge[0] + merge[1]] + tokens[i + 2:]
+                 else:
+                     i += 1
+         self.bpe_cache[word] = tokens
+         return tokens
+
+     def tokenize(self, text: str) -> List[str]:
+         pre_tokens = self.pre_tokenize(text)
+         final_tokens = []
+         for token in pre_tokens:
+             if token in self.special_tokens or token in self.vocab:
+                 final_tokens.append(token)
+             else:
+                 bpe_tokens = self.apply_bpe(token)
+                 final_tokens.extend(bpe_tokens)
+         return final_tokens
+
+     def encode_ids(self, text: str, add_special_tokens: bool = True) -> List[int]:
+         tokens = self.tokenize(text)
+         if add_special_tokens:
+             tokens = ['<bos>'] + tokens + ['<eos>']
+         ids = []
+         for token in tokens:
+             ids.append(self.vocab.get(token, self.vocab['<unk>']))
+         return ids
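+
+     # Note: decoding joins sub-word tokens without separators, so original
+     # whitespace between words is not recovered; only <|tab|> and <|newline|>
+     # are expanded back to literal characters.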
+     def decode_ids(self, ids: List[int], skip_special_tokens: bool = True) -> str:
+         tokens = []
+         for token_id in ids:
+             token = self.id_to_token.get(token_id, '<unk>')
+             if skip_special_tokens and token in self.special_tokens:
+                 continue
+             tokens.append(token)
+         text = ''.join(tokens)
+         text = text.replace('<|tab|>', '\t')
+         text = text.replace('<|newline|>', '\n')
+         return text
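+
+     # Persistence: save() writes vocab.json (token -> id), merges.txt (one
+     # space-separated merge pair per line), tokenizer_config.json, and
+     # token_frequencies.pkl; load() restores them and clears the BPE cache.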
+     def save(self, save_dir: str):
+         os.makedirs(save_dir, exist_ok=True)
+         with open(os.path.join(save_dir, 'vocab.json'), 'w', encoding='utf-8') as f:
+             json.dump(self.vocab, f, indent=2, ensure_ascii=False)
+         with open(os.path.join(save_dir, 'merges.txt'), 'w', encoding='utf-8') as f:
+             for merge in self.bpe_merges:
+                 f.write(f"{merge[0]} {merge[1]}\n")
+         config = {
+             'vocab_size': self.vocab_size,
+             'min_freq': self.min_freq,
+             'special_tokens': self.special_tokens,
+             'technical_terms': list(self.technical_terms)
+         }
+         with open(os.path.join(save_dir, 'tokenizer_config.json'), 'w', encoding='utf-8') as f:
+             json.dump(config, f, indent=2, ensure_ascii=False)
+         with open(os.path.join(save_dir, 'token_frequencies.pkl'), 'wb') as f:
+             pickle.dump(dict(self.token_frequencies), f)
+         print(f"Tokenizer saved to {save_dir}")
+
+     def load(self, save_dir: str):
+         with open(os.path.join(save_dir, 'vocab.json'), 'r', encoding='utf-8') as f:
+             self.vocab = json.load(f)
+         self.id_to_token = {v: k for k, v in self.vocab.items()}
+         with open(os.path.join(save_dir, 'merges.txt'), 'r', encoding='utf-8') as f:
+             self.bpe_merges = [tuple(line.strip().split()) for line in f if line.strip()]
+         config_file = os.path.join(save_dir, 'tokenizer_config.json')
+         if os.path.exists(config_file):
+             with open(config_file, 'r', encoding='utf-8') as f:
+                 config = json.load(f)
+             self.vocab_size = config.get('vocab_size', self.vocab_size)
+             self.min_freq = config.get('min_freq', self.min_freq)
+             if 'technical_terms' in config:
+                 self.technical_terms = set(config['technical_terms'])
+         freq_file = os.path.join(save_dir, 'token_frequencies.pkl')
+         if os.path.exists(freq_file):
+             with open(freq_file, 'rb') as f:
+                 self.token_frequencies = Counter(pickle.load(f))
+         self.bpe_cache = {}
+         print(f"Tokenizer loaded from {save_dir}")
+         print(f"Vocabulary size: {len(self.vocab)}")
+         print(f"Number of BPE merges: {len(self.bpe_merges)}")
+
+     def get_vocab_size(self) -> int:
+         return len(self.vocab)
+
+     def get_token_frequency(self, token: str) -> int:
+         return self.token_frequencies.get(token, 0)
+
+     def analyze_tokenization(self, text: str):
+         tokens = self.tokenize(text)
+         ids = self.encode_ids(text, add_special_tokens=False)
+         print(f"Original text: {text}")
+         print(f"Tokens: {tokens}")
+         print(f"Token IDs: {ids}")
+         print(f"Number of tokens: {len(tokens)}")
+         # Guard against empty token lists to avoid division by zero
+         ratio = len(text.split()) / len(tokens) if tokens else 0.0
+         print(f"Compression ratio: {ratio:.2f}")
+         return tokens, ids
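+
+
+ # ConversationDataset expects either a .jsonl file with one object per line of
+ # the form {"messages": [{"role": "user"|"assistant"|"system", "content": ...}]}
+ # (system turns are skipped, and at least two non-system turns are required),
+ # or a plain-text file of conversations separated by '<|endoftext|>' lines.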
+ class ConversationDataset:
+     """Dataset class for handling conversation data with the custom tokenizer."""
+
+     def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 512):
+         self.data_file = data_file
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+         self.conversations = []
+         self.load_conversations()
+
+     def load_conversations(self):
+         print(f"Loading conversations from {self.data_file}")
+         if self.data_file.endswith('.jsonl'):
+             self.load_jsonl()
+         else:
+             self.load_text()
+         print(f"Loaded {len(self.conversations)} conversations")
+
+     def load_jsonl(self):
+         with open(self.data_file, 'r', encoding='utf-8') as f:
+             for line in f:
+                 try:
+                     conv = json.loads(line.strip())
+                     messages = conv.get("messages", [])
+                     if not messages:
+                         continue
+                     text_parts = []
+                     for msg in messages:
+                         role = msg.get("role", "")
+                         content = msg.get("content", "").strip()
+                         if not content:
+                             continue
+                         if role == "system":
+                             continue
+                         elif role == "user":
+                             text_parts.append(f"<user> {content}")
+                         elif role == "assistant":
+                             text_parts.append(f"<assistant> {content}")
+                     if len(text_parts) >= 2:
+                         conversation_text = " ".join(text_parts) + " <|endoftext|>"
+                         self.conversations.append(conversation_text)
+                 except json.JSONDecodeError:
+                     continue
+
+     def load_text(self):
+         with open(self.data_file, 'r', encoding='utf-8') as f:
+             content = f.read()
+         conversations = content.split('<|endoftext|>\n')
+         for conv in conversations:
+             conv = conv.strip()
+             if conv:
+                 self.conversations.append(conv + " <|endoftext|>")
+
+     def get_tokenized_conversations(self, include_stats=False):
+         tokenized = []
+         stats = {'total_tokens': 0, 'truncated': 0, 'avg_length': 0}
+         for conv in self.conversations:
+             tokens = self.tokenizer.encode_ids(conv)
+             if len(tokens) > self.max_length:
+                 tokens = tokens[:self.max_length]
+                 stats['truncated'] += 1
+             tokenized.append(tokens)
+             stats['total_tokens'] += len(tokens)
+         if tokenized:
+             stats['avg_length'] = stats['total_tokens'] / len(tokenized)
+         if include_stats:
+             return tokenized, stats
+         return tokenized
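+
+     # Long conversations are split into overlapping windows of at most
+     # max_length tokens, with a stride of max_length // 2 by default;
+     # windows shorter than 32 tokens are dropped.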
+     def create_training_examples(self, stride: Optional[int] = None):
+         if stride is None:
+             stride = self.max_length // 2
+         examples = []
+         for conv in self.conversations:
+             tokens = self.tokenizer.encode_ids(conv)
+             if len(tokens) <= self.max_length:
+                 examples.append(tokens)
+             else:
+                 for i in range(0, len(tokens), stride):
+                     window = tokens[i:i + self.max_length]
+                     if len(window) >= 32:
+                         examples.append(window)
+         return examples
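+
+
+ # Convenience driver: gather raw text from .jsonl conversation files (message
+ # contents only) or from plain-text files split on blank lines, optionally
+ # subsample to max_texts, train the BPE vocabulary, save it to output_dir,
+ # and print a few sample tokenizations.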
+ def train_tokenizer_from_files(file_paths: List[str],
+                                vocab_size: int = 32000,
+                                min_freq: int = 2,
+                                output_dir: str = "tokenizer",
+                                max_texts: Optional[int] = None):
+     print(f"Training tokenizer with vocab_size={vocab_size}")
+     print(f"Input files: {file_paths}")
+     all_texts = []
+     for file_path in file_paths:
+         print(f"Loading {file_path}...")
+         if file_path.endswith('.jsonl'):
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 for line in f:
+                     try:
+                         conv = json.loads(line.strip())
+                         messages = conv.get("messages", [])
+                         text_parts = []
+                         for msg in messages:
+                             content = msg.get("content", "").strip()
+                             if content:
+                                 text_parts.append(content)
+                         if text_parts:
+                             all_texts.append(" ".join(text_parts))
+                     except json.JSONDecodeError:
+                         continue
+         else:
+             with open(file_path, 'r', encoding='utf-8') as f:
+                 content = f.read()
+             chunks = content.split('\n\n')
+             for chunk in chunks:
+                 if chunk.strip():
+                     all_texts.append(chunk.strip())
+     print(f"Loaded {len(all_texts)} texts")
+     if max_texts and len(all_texts) > max_texts:
+         import random
+         random.shuffle(all_texts)
+         all_texts = all_texts[:max_texts]
+         print(f"Limited to {len(all_texts)} texts")
+     tokenizer = TechnicalTokenizer(vocab_size=vocab_size, min_freq=min_freq)
+     tokenizer.train_bpe(all_texts)
+     tokenizer.save(output_dir)
+     print("\nTesting tokenization on sample texts:")
+     test_texts = [
+         "Hello, how can I help you with your Python programming question?",
+         "The neural network has 3 hidden layers with ReLU activation functions.",
+         "```python\ndef fibonacci(n):\n    if n <= 1:\n        return n\n    return fibonacci(n-1) + fibonacci(n-2)\n```",
+         "The derivative of x^2 is 2x, and the integral is (x^3)/3 + C."
+     ]
+     for text in test_texts:
+         tokenizer.analyze_tokenization(text)
+         print()
+     return tokenizer
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train custom tokenizer for technical content")
+     parser.add_argument("--input_files", nargs='+', help="Input text/jsonl files")
+     parser.add_argument("--output_dir", default="tokenizer", help="Output directory for tokenizer")
+     parser.add_argument("--vocab_size", type=int, default=32000, help="Vocabulary size")
+     parser.add_argument("--min_freq", type=int, default=2, help="Minimum token frequency")
+     parser.add_argument("--max_texts", type=int, help="Maximum number of texts to use for training")
+     parser.add_argument("--test_file", help="Test file for analyzing tokenization")
+     parser.add_argument("--load_tokenizer", help="Load existing tokenizer from directory")
+     args = parser.parse_args()
+     default_input_file = "/kaggle/input/gpt-based-slm-dataset/slm_training_complete.jsonl"
+     default_text_file = "/kaggle/working/text_data/training_data_chat.txt"
+     if not args.input_files and not args.load_tokenizer:
+         if os.path.exists(default_input_file):
+             args.input_files = [default_input_file]
+             print(f"No arguments provided, using default input file: {default_input_file}")
+         elif os.path.exists(default_text_file):
+             args.input_files = [default_text_file]
+             print(f"No arguments provided, using default text file: {default_text_file}")
+         else:
+             parser.error("No input files or tokenizer directory provided, and default files not found. "
+                          "Please specify --input_files or --load_tokenizer.")
+     if args.load_tokenizer:
+         tokenizer = TechnicalTokenizer()
+         tokenizer.load(args.load_tokenizer)
+         if args.test_file:
+             print(f"\nTesting on {args.test_file}")
+             dataset = ConversationDataset(args.test_file, tokenizer)
+             tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
+             print("Dataset statistics:")
+             print(f"  Total conversations: {len(tokenized)}")
+             print(f"  Total tokens: {stats['total_tokens']:,}")
+             print(f"  Average tokens per conversation: {stats['avg_length']:.1f}")
+             print(f"  Conversations truncated: {stats['truncated']}")
+     else:
+         tokenizer = train_tokenizer_from_files(
+             file_paths=args.input_files,
+             vocab_size=args.vocab_size,
+             min_freq=args.min_freq,
+             output_dir=args.output_dir,
+             max_texts=args.max_texts
+         )
+         if args.test_file:
+             print(f"\nTesting on {args.test_file}")
+             dataset = ConversationDataset(args.test_file, tokenizer)
+             tokenized, stats = dataset.get_tokenized_conversations(include_stats=True)
+             print("Dataset statistics:")
+             print(f"  Total conversations: {len(tokenized)}")
+             print(f"  Total tokens: {stats['total_tokens']:,}")
+             print(f"  Average tokens per conversation: {stats['avg_length']:.1f}")
+             print(f"  Conversations truncated: {stats['truncated']}")
+
+
+ if __name__ == "__main__":
+     main()
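+
+ # Example usage (illustrative; the file and directory names below are
+ # placeholders, not files shipped with this commit):
+ #
+ #   python custom_tokenizer.py --input_files data/conversations.jsonl \
+ #       --output_dir tokenizer --vocab_size 32000 --min_freq 2
+ #
+ #   # Reload and inspect an existing tokenizer on a held-out file:
+ #   python custom_tokenizer.py --load_tokenizer tokenizer --test_file data/dev.jsonl
+ #
+ # Or from Python:
+ #
+ #   tok = TechnicalTokenizer()
+ #   tok.load("tokenizer")
+ #   ids = tok.encode_ids("Explain how attention works in a transformer.")
+ #   print(tok.decode_ids(ids))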