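# SentencePiece tokenizer utilities: an AbstractTokenizer interface, an
# SPieceTokenizer wrapper around the sentencepiece library, and helpers that
# resolve the [extra_id_*] chat/media control tokens into an ExtraTokens record.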
from abc import ABC, abstractmethod
import sentencepiece as spm
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
from typing import Any, Union
import numpy as np
from dataclasses import dataclass


def encode_pieces(sp_model: spm.SentencePieceProcessor, text: str, sample=False):
    """Encode text into sentence pieces. Only supports py3."""
    if not sample:
        pieces = sp_model.EncodeAsPieces(text)
    else:
        pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
    return pieces


class AbstractTokenizer(ABC):
    """Abstract class for tokenizer."""

    def __init__(self, name):
        self.name = name
        super().__init__()

    @property
    @abstractmethod
    def vocab_size(self):
        pass

    @property
    @abstractmethod
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        pass

    @property
    @abstractmethod
    def inv_vocab(self):
        """Dictionary from vocab id token to text token."""
        pass

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError('detokenizer is not implemented for {} '
                                  'tokenizer'.format(self.name))

    @property
    def cls(self):
        raise NotImplementedError('CLS is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def sep(self):
        raise NotImplementedError('SEP is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def pad(self):
        raise NotImplementedError('PAD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def eod(self):
        raise NotImplementedError('EOD is not provided for {} '
                                  'tokenizer'.format(self.name))

    @property
    def mask(self):
        raise NotImplementedError('MASK is not provided for {} '
                                  'tokenizer'.format(self.name))


class SPieceTokenizer(AbstractTokenizer):

    def __init__(self, spm_file: str):
        super().__init__('Sentence Piece')
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(spm_file)
        self.eod_id = self.get_token_id('</s>')
        self.special_ids = set([
            self.sp_model.pad_id(),
            self.sp_model.eos_id(),
            self.sp_model.bos_id(),
            self.sp_model.unk_id(),
            self.eod_id,
        ])
        # initialize index_2_bytes
        self._initialize_index_2_bytes()

    def encode_pieces(self, text: str, sample=False):
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _initialize_index_2_bytes(self):
        # record how many UTF-8 bytes each piece contributes, ignoring the '▁' word marker
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        self.index_2_numbytes = [0] * len(proto.pieces)
        for i, p in enumerate(proto.pieces):
            clean_piece = p.piece.replace('▁', '')
            self.index_2_numbytes[i] = len(clean_piece.encode('utf-8'))

    def set_add_dummy_prefix(self, add_dummy_prefix: bool = False):
        # toggle SentencePiece's add_dummy_prefix option by rewriting the serialized model proto
        proto = sp_pb2_model.ModelProto()
        proto.ParseFromString(self.sp_model.serialized_model_proto())
        if proto.normalizer_spec.add_dummy_prefix != add_dummy_prefix:
            proto.normalizer_spec.add_dummy_prefix = add_dummy_prefix
            self.sp_model.LoadFromSerializedProto(proto.SerializeToString())
            print(f"> set add_dummy_prefix to {add_dummy_prefix} ...", flush=True)

    def add_special_id(self, token_id):
        self.special_ids.add(token_id)

    def has_dummy_prefix(self):
        pieces = self.sp_model.EncodeAsPieces("hello")
        return pieces[0].startswith('▁')

    @property
    def vocab_size(self):
        return self.sp_model.GetPieceSize()

    @property
    def vocab(self):
        """Dictionary from vocab text token to id token."""
        return self.sp_model

    def get_array_bytes(self, array):
        # approximate byte length of the decoded ids; out-of-vocab ids count as 2 bytes
        return sum(self.index_2_numbytes[i] if i < self.vocab_size else 2 for i in array)

    def tokenize(self, text):
        tokens = encode_pieces(self.sp_model, text)
        return self.convert_tokens_to_ids(tokens)

    def encode(self, text: str, bos: bool = False, eos: bool = False, **kwargs: Any) -> list[int]:
        tokens = self.encode_pieces(text)
        t = self.convert_tokens_to_ids(tokens)
        if bos:
            t.insert(0, self.bos_id)
        if eos:
            t.append(self.eos_id)
        return t

    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
        if isinstance(tokens, str):
            return self.sp_model.PieceToId(tokens)
        return [self.sp_model.PieceToId(token) for token in tokens]

    def detokenize(self, token_ids):
        if isinstance(token_ids, list):
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids]
        else:
            pieces = [self.sp_model.IdToPiece(id) for id in token_ids.tolist()]
        return pieces

    def decode(self, token_ids: Union[int, list[int]], skip_special_tokens: bool = False) -> str:
        assert not skip_special_tokens, "skip_special_tokens is not supported"
        if isinstance(token_ids, (int, np.integer)):
            return self.detokenize([int(token_ids)])[0]
        return ''.join(self.detokenize(token_ids))

    def get_token_id(self, token):
        return self.sp_model.PieceToId(token)

    @property
    def inv_vocab(self):
        # TODO: to be implemented
        return {}

    def decode_pieces(self, pieces):
        return self.sp_model.DecodePieces(pieces)

    @property
    def eod(self):
        return self.eod_id

    @property
    def pad_id(self):
        return self.sp_model.pad_id()

    @property
    def eos_id(self):
        return self.sp_model.eos_id()

    @property
    def bos_id(self):
        return self.sp_model.bos_id()

    @property
    def unk_id(self):
        return self.sp_model.unk_id()

    @property
    def pad_token_id(self):
        return self.pad_id

    @property
    def eos_token_id(self):
        return self.eos_id


@dataclass
class ExtraTokens:
    # ids of the chat-format and media control tokens
    msg_end: int
    user_msg_start: int
    assistant_msg_start: int
    name_end: int
    media_begin: int
    media_content: int
    media_end: int
    pad: int


def instantiate_extra_tokens(tokenizer: AbstractTokenizer):
    if isinstance(tokenizer, SPieceTokenizer):
        map_fn = lambda x: tokenizer.convert_tokens_to_ids(x)
    else:
        raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}")
    return ExtraTokens(
        msg_end=map_fn('[extra_id_0]'),
        user_msg_start=map_fn('[extra_id_1]'),
        assistant_msg_start=map_fn('[extra_id_2]'),
        name_end=map_fn('[extra_id_12]'),
        media_begin=map_fn('[extra_id_13]'),
        media_content=map_fn('[extra_id_14]'),
        media_end=map_fn('[extra_id_15]'),
        pad=tokenizer.pad_id,
    )


def get_tokenizer_and_extra_tokens():
    sp_model_path = "resources/tokenizer/160k.model"
    tokenizer = SPieceTokenizer(sp_model_path)
    tokenizer.set_add_dummy_prefix(False)
    extra_tokens = instantiate_extra_tokens(tokenizer)
    return tokenizer, extra_tokens
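

# Usage sketch (illustrative; not part of the original module). It assumes the
# SentencePiece model file hard-coded above, "resources/tokenizer/160k.model",
# is available relative to the working directory.
if __name__ == "__main__":
    tokenizer, extra_tokens = get_tokenizer_and_extra_tokens()
    ids = tokenizer.encode("hello world")
    print("token ids :", ids)
    print("decoded   :", tokenizer.decode(ids))
    print("user_msg_start id:", extra_tokens.user_msg_start)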