Update custom_tokenizers.py

custom_tokenizers.py  CHANGED  (+23 -28)
@@ -1,9 +1,5 @@
-
-
-"""
-
-from transformers import T5Tokenizer, PreTrainedTokenizer
-from typing import Dict, List, Optional, Tuple, Union
+from transformers import T5Tokenizer
+from typing import Dict, List, Optional, Union
 import os
 import logging

@@ -27,7 +23,6 @@ class Byt5LangTokenizer(T5Tokenizer):
         sp_model_kwargs=None,
         **kwargs
     ):
-        # Basic definition based on T5Tokenizer
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
@@ -40,47 +35,47 @@ class Byt5LangTokenizer(T5Tokenizer):
             **kwargs
         )

-        #
-        self.
-        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+        # Create byte_decoder: this is the key piece that was missing
+        self.byte_decoder = {i: bytes([i]) for i in range(256)}

     @property
     def vocab_size(self):
-        # ByT5 uses a byte-level
+        # ByT5 uses a byte-level representation (256) plus the special tokens
         return 256 + self.num_special_tokens

-    def get_vocab(self):
-        #
-        vocab = {
+    def get_vocab(self) -> Dict[str, int]:
+        # Vocabulary of byte strings plus the special tokens
+        vocab = {chr(i): i for i in range(256)}
         vocab.update(self.special_tokens_encoder)
         return vocab

-    def _tokenize(self, text):
-        #
+    def _tokenize(self, text: str) -> List[Union[int, str]]:
+        # Turns the text into a sequence of bytes (ints), as ByT5 does
         return list(text.encode("utf-8"))

-    def _convert_token_to_id(self, token):
-        # Convert the token to an ID
+    def _convert_token_to_id(self, token: Union[str, int]) -> int:
         if isinstance(token, str):
             if token in self.special_tokens_encoder:
                 return self.special_tokens_encoder[token]
             else:
-
+                try:
+                    return ord(token)
+                except TypeError:
+                    return token
         return token

-    def _convert_id_to_token(self, index):
-        # Convert the ID to a token
+    def _convert_id_to_token(self, index: int) -> Union[str, int]:
         if index in self.special_tokens_decoder:
             return self.special_tokens_decoder[index]
         else:
             return chr(index)

-    def convert_tokens_to_string(self, tokens):
-        #
-
+    def convert_tokens_to_string(self, tokens: List[Union[str, int]]) -> str:
+        # Converts the list of tokens back into a string
+        decoded = b""
         for token in tokens:
-            if token
-
+            if isinstance(token, int):
+                decoded += bytes([token])
             else:
-
-            return
+                decoded += token.encode("utf-8")
+        return decoded.decode("utf-8", errors="replace")
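
For reference, the byte-level round trip that the updated methods implement can be exercised in isolation. The sketch below is not part of the commit: it mirrors _tokenize and convert_tokens_to_string as standalone functions and assumes no special tokens are involved, so the effect of accumulating raw bytes before decoding is easy to verify.

from typing import List, Union

def tokenize(text: str) -> List[int]:
    # Mirrors _tokenize: UTF-8 bytes as integer tokens.
    return list(text.encode("utf-8"))

def tokens_to_string(tokens: List[Union[int, str]]) -> str:
    # Mirrors convert_tokens_to_string: accumulate raw bytes first,
    # then decode once, so multi-byte UTF-8 characters survive intact.
    decoded = b""
    for token in tokens:
        if isinstance(token, int):
            decoded += bytes([token])
        else:
            decoded += token.encode("utf-8")
    return decoded.decode("utf-8", errors="replace")

text = "привет, ByT5"       # Cyrillic letters take two UTF-8 bytes each
tokens = tokenize(text)     # e.g. [208, 191, 209, 128, ...]
assert tokens_to_string(tokens) == text

Decoding each byte separately with chr() would mangle the multi-byte characters; collecting the bytes and decoding once, with errors="replace" as a fallback for invalid sequences, is the behavior the updated convert_tokens_to_string provides.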