import re

from sentencepiece import SentencePieceProcessor


def replace_spaces_with_blank(match: re.Match[str]) -> str:
    # e.g. a match of 4 spaces -> "<|blank_4|>"
    return f"<|blank_{len(match.group())}|>"


def replace_blank_with_spaces(match: re.Match[str]) -> str:
    # e.g. "<|blank_4|>" -> 4 spaces
    return " " * int(match.group(1))


class ChatGLMTokenizer:
    def __init__(self, vocab_file: str):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.text_tokenizer = SentencePieceProcessor(str(vocab_file))

    def __len__(self):
        return len(self.text_tokenizer)

    def __getitem__(self, key: str):
        # Map a piece such as "[gMASK]" to its token id.
        return self.text_tokenizer[key]

    def preprocess(self, text: str, linebreak=True, whitespaces=True) -> str:
        # Make whitespace explicit so SentencePiece preserves it exactly:
        # newline -> <n>, tab -> <|tab|>, runs of 2-80 spaces -> <|blank_N|>.
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = text.replace("\t", "<|tab|>")
            text = re.sub(r" {2,80}", replace_spaces_with_blank, text)
        return text
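
    # Example: preprocess("def f():\n\treturn  1")
    #   == "def f():<n><|tab|>return<|blank_2|>1"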

    def encode(
        self, text: str, text_pair: str | None = None,
        linebreak=True, whitespaces=True,
        add_dummy_prefix=True, special_tokens=True,
    ) -> tuple[list[int], list[int]]:
        """
        text: Text to encode. This is the bidirectional (prefix) part; [gMASK]
            and <sop> are appended to mark where the causal-LM part begins.
        text_pair: Text for the causal-LM part, if any.
        linebreak: Whether to encode newlines (\n) in text as <n>.
        whitespaces: Whether to encode tabs and runs of spaces, useful for
            encoding source code.
        special_tokens: Whether to append the special tokens [gMASK] and <sop>
            (and <eop> after text_pair).
        add_dummy_prefix: Whether to keep the dummy blank space SentencePiece
            adds at the beginning.
        """
        text = self.preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # Prepend <n> so the dummy prefix can be stripped below by
            # dropping the first two tokens.
            text = "<n>" + text
        tokens = self.text_tokenizer.encode(text)
        # 1 marks bidirectional (prefix) positions, 0 marks causal positions.
        prefix_mask = [1] * len(tokens)
        if special_tokens:
            tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
            prefix_mask += [1, 0]
        if text_pair is not None:
            text_pair = self.preprocess(text_pair, linebreak, whitespaces)
            pair_tokens = self.text_tokenizer.encode(text_pair)
            tokens += pair_tokens
            prefix_mask += [0] * len(pair_tokens)
            if special_tokens:
                tokens += [self.text_tokenizer["<eop>"]]
                prefix_mask += [0]
        if not add_dummy_prefix:
            # Drop the <n> and dummy-space tokens, keeping the mask aligned.
            tokens, prefix_mask = tokens[2:], prefix_mask[2:]
        return tokens, prefix_mask
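
    # Result layout with special_tokens=True and a text_pair:
    #   tokens:      [ text ids..., [gMASK], <sop>, pair ids..., <eop> ]
    #   prefix_mask: [ 1, 1, ...,   1,       0,     0, 0, ...,   0     ]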

    def decode(self, text_ids: list[int]) -> str:
        text = self.text_tokenizer.decode(text_ids)
        # Undo the whitespace encoding applied by preprocess().
        text = text.replace("<n>", "\n")
        text = text.replace("<|tab|>", "\t")
        text = re.sub(r"<\|blank_(\d\d?)\|>", replace_blank_with_spaces, text)
        return text
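

# Minimal usage sketch (not part of the original class). "ice_text.model" is an
# assumed placeholder for the SentencePiece vocab file shipped with ChatGLM-6B;
# substitute your own copy.
if __name__ == "__main__":
    tokenizer = ChatGLMTokenizer("ice_text.model")
    prompt = "def hello():\n\tprint('hi')"
    tokens, prefix_mask = tokenizer.encode(prompt)
    assert len(tokens) == len(prefix_mask)
    # The trailing [gMASK] and <sop> mark where generation would start.
    print(tokens[-2:], prefix_mask[-2:])
    # Dropping them, decode() should restore the original text exactly.
    print(tokenizer.decode(tokens[:-2]) == prompt)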