Spaces:
Paused
Paused
| import json | |
| from transformers import AutoTokenizer, T5ForConditionalGeneration | |
| from diffusers.utils import logging | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
def add_special_token(tokenizer, text_encoder, add_color, add_font, color_ann_path, font_ann_path, multilingual=False):
    """Add color/font marker tokens to ``tokenizer`` and resize ``text_encoder``.

    Token naming:
      * color: ``<color-i>`` for ``i`` in ``range(number of color entries)``.
      * font (default): ``<font-i>`` for ``i`` in ``range(number of font entries)``.
      * font (multilingual): ``<{prefix}-font-{value}>`` for every key of the
        font annotation, where ``prefix`` is the first two characters of the
        key and ``value`` is the entry mapped to that key.

    Color tokens are always registered before font tokens.

    Args:
        tokenizer: Object exposing ``add_tokens(tokens, special_tokens=...)``
            and ``__len__``.
        text_encoder: Object exposing ``resize_token_embeddings(size)``.
        add_color: When truthy, register the color tokens.
        add_font: When truthy, register the font tokens.
        color_ann_path: Path to the JSON color annotation. Only read when
            ``add_color`` is truthy.
        font_ann_path: Path to the JSON font annotation. Only read when
            ``add_font`` is truthy.
        multilingual: Selects the prefixed font-token scheme described above.
            NOTE(review): presumably the key prefix is a language code --
            confirm against the annotation file.

    Returns:
        None. ``tokenizer`` and ``text_encoder`` are modified in place.

    Raises:
        OSError: If a requested annotation file cannot be opened.
        json.JSONDecodeError: If a requested annotation file is not valid JSON.
    """
    additional_special_tokens = []

    if add_color:
        # Read the annotation only when it is needed, so a caller that does
        # not want color tokens is not forced to ship the color file.
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(color_ann_path, 'r', encoding='utf-8') as f:
            idx_color_dict = json.load(f)
        additional_special_tokens += [
            f'<color-{i}>' for i in range(len(idx_color_dict))
        ]

    if add_font:
        with open(font_ann_path, 'r', encoding='utf-8') as f:
            idx_font_dict = json.load(f)
        if multilingual:
            # The first two characters of each key become the token prefix.
            additional_special_tokens += [
                f'<{font_code[:2]}-font-{font_idx}>'
                for font_code, font_idx in idx_font_dict.items()
            ]
        else:
            additional_special_tokens += [
                f'<font-{i}>' for i in range(len(idx_font_dict))
            ]

    # Always performed (even with an empty token list) so the embedding
    # matrix is kept in sync with the tokenizer length.
    tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
    text_encoder.resize_token_embeddings(len(tokenizer))
def load_byt5_and_byt5_tokenizer(
    byt5_name='google/byt5-small',
    special_token=False,
    color_special_token=False,
    font_special_token=False,
    color_ann_path='assets/color_idx.json',
    font_ann_path='assets/font_idx_512.json',
    huggingface_cache_dir=None,
    multilingual=False,
):
    """Load a ByT5 tokenizer and the encoder half of the matching T5 model.

    The tokenizer comes from ``AutoTokenizer.from_pretrained`` and the encoder
    from ``T5ForConditionalGeneration.from_pretrained(...).get_encoder()``,
    both using ``byt5_name`` and ``huggingface_cache_dir``.

    Args:
        byt5_name: Model identifier or path passed to ``from_pretrained``.
        special_token: Master switch. Only when truthy is
            ``add_special_token`` called; otherwise the color/font flags and
            annotation paths below are unused.
        color_special_token: Forwarded as ``add_color``.
        font_special_token: Forwarded as ``add_font``.
        color_ann_path: Forwarded as ``color_ann_path``.
        font_ann_path: Forwarded as ``font_ann_path``.
        huggingface_cache_dir: Forwarded as ``cache_dir`` to both
            ``from_pretrained`` calls.
        multilingual: Forwarded as ``multilingual``.

    Returns:
        A ``(text_encoder, tokenizer)`` tuple -- note the encoder comes first.
    """
    pretrained_kwargs = {'cache_dir': huggingface_cache_dir}

    tokenizer = AutoTokenizer.from_pretrained(byt5_name, **pretrained_kwargs)
    # Only the encoder stack of the seq2seq model is kept.
    encoder = T5ForConditionalGeneration.from_pretrained(
        byt5_name, **pretrained_kwargs
    ).get_encoder()

    if special_token:
        # Extends the tokenizer vocabulary and resizes the encoder embeddings.
        add_special_token(
            tokenizer,
            encoder,
            add_color=color_special_token,
            add_font=font_special_token,
            color_ann_path=color_ann_path,
            font_ann_path=font_ann_path,
            multilingual=multilingual,
        )

    logger.info('Loaded original byt5 weight')
    return encoder, tokenizer