fix: make `eos_token`/`pad_token` overridable and add `pickle` support
tokenization_arcade100k.py (+17, -3)
```diff
@@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
 
         self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
         self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
-        self.eos_token = self.decoder[self.tokenizer.eot_token]
-        self.pad_token = self.decoder[self.tokenizer.pad_token]
+        # Provide default `eos_token` and `pad_token`
+        if self.eos_token is None:
+            self.eos_token = self.decoder[self.tokenizer.eot_token]
+        if self.pad_token is None:
+            self.pad_token = self.decoder[self.tokenizer.pad_token]
+
         # Expose for convenience
         self.mergeable_ranks = self.tokenizer._mergeable_ranks
         self.special_tokens = self.tokenizer._special_tokens
```
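The `is None` guards are what make the defaults overridable: a value supplied through the constructor (and presumably set by `PreTrainedTokenizer.__init__` before this block runs) is no longer clobbered by the `eot_token`/`pad_token` defaults. A minimal usage sketch; the repo id and the `<|im_end|>` override below are illustrative, not taken from this commit:

```python
from transformers import AutoTokenizer

# Default: eos/pad fall back to the tokenizer's own entries via self.decoder.
tok = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-2-1_6b", trust_remote_code=True
)
print(tok.eos_token)  # e.g. "<|endoftext|>"

# Override: the new `is None` guard leaves a caller-supplied value alone.
tok = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-2-1_6b",
    trust_remote_code=True,
    eos_token="<|im_end|>",  # illustrative; any token the vocab knows works
)
print(tok.eos_token)  # "<|im_end|>"
```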
```diff
@@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
     def __len__(self):
         return self.tokenizer.n_vocab
 
+    def __getstate__(self):
+        # Required for `pickle` support
+        state = self.__dict__.copy()
+        del state["tokenizer"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
+
     @property
     def vocab_size(self):
         return self.tokenizer.n_vocab
```
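`tiktoken.Encoding` wraps a Rust-backed object that `pickle` cannot serialize, so `__getstate__` drops it and `__setstate__` rebuilds it from the stored constructor arguments; this relies on `__init__` stashing them as `self._tiktoken_config`. A sketch of the round-trip this enables, e.g. for `multiprocessing` workers or `datasets.map(num_proc=...)` (repo id illustrative as above):

```python
import pickle

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-2-1_6b", trust_remote_code=True
)

# __getstate__ strips the unpicklable tiktoken.Encoding before dumping...
blob = pickle.dumps(tok)
# ...and __setstate__ re-creates it from the saved _tiktoken_config.
restored = pickle.loads(blob)

assert restored.encode("hello world") == tok.encode("hello world")
```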
```diff
@@ -273,4 +287,4 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
         token_ids = [token_ids]
         if skip_special_tokens:
             token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
-            return self.tokenizer.decode(token_ids)
+        return self.tokenizer.decode(token_ids)
```
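The last hunk reads as an indentation fix (the `-`/`+` lines carry identical text, so only the leading whitespace changed): with the `return` inside the `if`, decoding fell through to `None` whenever `skip_special_tokens` was false. Continuing with `tok` from the sketch above:

```python
ids = tok.encode("hello world")
# Previously this returned None; it now returns the decoded text.
print(tok.decode(ids, skip_special_tokens=False))
```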