Spaces:
Runtime error
Runtime error
| import json | |
| import math | |
| import os | |
| import torch | |
| import natsort | |
| from vita_audio.tokenizer import get_audio_tokenizer | |
class AudioProcessor:
    """Wraps an audio tokenizer and adds best-effort disk caching for discrete encodings."""

    def __init__(
        self,
        audio_tokenizer_path=None,
        audio_tokenizer_type=None,
        text_audio_interval_ratio=None,
    ):
        # Tokenizer construction is assumed cheap here; heavy weights are
        # loaded lazily via load_model().
        self.audio_tokenizer = get_audio_tokenizer(
            audio_tokenizer_path,
            audio_tokenizer_type,
        )
        self.audio_tokenizer_type = audio_tokenizer_type
        self.text_audio_interval_ratio = text_audio_interval_ratio

    def load_model(self):
        """Load the underlying tokenizer weights, if a tokenizer is configured."""
        if self.audio_tokenizer is not None:
            self.audio_tokenizer.load_model()

    def process_audios(self, audio_path, is_discrete=False, is_contiguous=False, **kwargs):
        """Encode ``audio_path`` with the tokenizer.

        Exactly one of ``is_discrete`` / ``is_contiguous`` must be True.
        Discrete results (plain lists) are cached as a JSON file next to the
        audio file; cache read/write failures are silently ignored and the
        audio is simply re-encoded.
        """
        assert not (is_discrete and is_contiguous)
        assert is_discrete or is_contiguous

        if is_discrete:
            # Cache file name: "<audio stem>_<last underscore segment of
            # tokenizer type>.json".
            audio_tokenizer_type = self.audio_tokenizer_type.split("_")[-1]
            cache_path = os.path.splitext(audio_path)[0] + f"_{audio_tokenizer_type}.json"
            # Best-effort cache read: any failure (corrupt JSON, I/O error)
            # falls through to re-encoding below.
            try:
                if os.path.isfile(cache_path):
                    with open(cache_path, "r") as f:
                        return json.load(f)
            except Exception:
                pass

        audio_data = self.audio_tokenizer.encode(
            audio_path, is_discrete=is_discrete, is_contiguous=is_contiguous, **kwargs
        )

        if is_discrete:
            # Best-effort cache write; only plain lists are JSON-serializable
            # here, other return types are not cached.
            try:
                if isinstance(audio_data, list):
                    with open(cache_path, "w") as f:
                        json.dump(audio_data, f)
            except Exception:
                pass

        return audio_data

    def is_discrete(self):
        """Whether the tokenizer produces discrete tokens."""
        return self.audio_tokenizer.is_discrete

    def is_contiguous(self):
        """Whether the tokenizer produces contiguous (embedding) features."""
        return self.audio_tokenizer.is_contiguous

    def apply_to_role(self, role, **kwargs):
        """Delegate role-based prompt handling to the tokenizer."""
        return self.audio_tokenizer.apply_to_role(role, **kwargs)

    def text_audio_interval(self, content_input_id, AUD_START_ID, AUD_END_ID):
        # Delegates to the module-level text_audio_interval() using the
        # ratio configured on this instance.
        return text_audio_interval(
            content_input_id,
            AUD_START_ID,
            AUD_END_ID,
            self.text_audio_interval_ratio,
        )
def add_audio_input_contiguous(input_ids, audio_paths, tokenizer, audio_tokenizer):
    """Replace each AUD_TAG token with a contiguous-audio placeholder span.

    For every AUD_TAG in ``input_ids``, encodes the corresponding entry of
    ``audio_paths`` and splices ``AUD_START + AUD_CONTEXT * L + AUD_END`` into
    the sequence, where L is the encoded frame count plus 4 (extra positions —
    presumably boundary/special frames; TODO confirm against the model config).

    Returns:
        (new_input_ids, audios, audio_indices) where ``audio_indices[k]`` is a
        2 x 1 x L tensor of (batch, sequence-position) indices for span k; the
        batch row is zero-filled here and rewritten in collate_fn.
    """
    from ...constants import (
        AUD_START_TOKEN,
        AUD_END_TOKEN,
        AUD_TAG_TOKEN,
        AUD_CONTEXT_TOKEN,
    )

    # Each special token is assumed to tokenize to exactly one id.
    AUD_CONTEXT_ID = tokenizer(AUD_CONTEXT_TOKEN, add_special_tokens=False).input_ids[0]
    AUD_TAG_ID = tokenizer(AUD_TAG_TOKEN, add_special_tokens=False).input_ids[0]
    AUD_START_ID = tokenizer(AUD_START_TOKEN, add_special_tokens=False).input_ids[0]
    AUD_END_ID = tokenizer(AUD_END_TOKEN, add_special_tokens=False).input_ids[0]

    aud_positions = [i for i, x in enumerate(input_ids) if x == AUD_TAG_ID]

    audios = []
    audio_indices = []
    new_input_ids = []
    st = 0
    for aud_idx, aud_pos in enumerate(aud_positions):
        audio = audio_tokenizer.encode(audio_paths[aud_idx], is_contiguous=True)
        audios.append(audio)
        # +4 extra positions beyond the raw frame count (see docstring note).
        audio_token_length = audio.size(0) + 4

        # Copy the text up to the tag, then open the audio span.
        new_input_ids += input_ids[st:aud_pos]
        new_input_ids += [AUD_START_ID]

        # Batch index placeholder; rewritten later in collate_fn.
        audio_indice_b = torch.zeros(1, audio_token_length, dtype=torch.int64)
        # Sequence positions the span will occupy (starts right after AUD_START).
        audio_indice_s = torch.arange(
            len(new_input_ids), len(new_input_ids) + audio_token_length
        ).unsqueeze(0)
        # Stack to 2 x 1 x L: (batch, seq) index pairs for this span.
        audio_indices.append(torch.stack([audio_indice_b, audio_indice_s], dim=0))

        new_input_ids += [AUD_CONTEXT_ID] * audio_token_length
        new_input_ids += [AUD_END_ID]
        st = aud_pos + 1  # skip the consumed AUD_TAG token

    new_input_ids += input_ids[st:]
    return new_input_ids, audios, audio_indices
def text_audio_interval_old(input_ids, AUD_START_ID, AUD_END_ID, text_audio_interval_ratio):
    """Deprecated: interleave one audio span with surrounding text at a fixed cadence.

    NOTE(review): ``text_audio_interval_ratio`` is effectively ignored —
    text_num/audio_num are unconditionally overridden to 4/10 below (preserved
    for behavioral parity; see text_audio_interval for the current version).
    Sequences with no audio span, or audio only, are returned unchanged.
    """
    if text_audio_interval_ratio is not None:
        text_num, audio_num = text_audio_interval_ratio
    else:
        text_num = 13
        audio_num = 26
    # Unconditional override: the branch above (and the caller's ratio) is dead.
    text_num = 4
    audio_num = 10

    # Exclude the AUD_START / AUD_END markers from each audio chunk budget.
    audio_num = audio_num - 2

    st = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
    ed = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]

    # Text-only sequence: nothing to interleave.
    if len(st) == 0 and len(ed) == 0:
        return input_ids

    # Exactly one audio span is supported.
    assert len(st) == 1
    assert len(ed) == 1
    st = st[0]
    ed = ed[0]
    assert st < ed

    # Audio-only sequence: nothing to interleave.
    if st == 0 and ed == len(input_ids) - 1:
        return input_ids

    audio_tokens = input_ids[st + 1 : ed]
    text_tokens = input_ids[:st] + input_ids[ed + 1 :]

    # First text chunk is a single token so audio playback starts early;
    # subsequent chunks use the full budgets.
    audio_tokens_chunks = [audio_tokens[:audio_num]] + [
        audio_tokens[i : i + audio_num] for i in range(audio_num, len(audio_tokens), audio_num)
    ]
    text_tokens_chunks = [text_tokens[:1]] + [
        text_tokens[i : i + text_num] for i in range(1, len(text_tokens), text_num)
    ]

    # Merge surplus chunks of the longer stream into its final chunk so the
    # two streams pair up one-to-one.
    chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
    audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
        sum(audio_tokens_chunks[chunk_num - 1 :], [])
    ]
    text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
        sum(text_tokens_chunks[chunk_num - 1 :], [])
    ]

    interval_input_ids = []
    for text_chunk, audio_chunk in zip(text_tokens_chunks, audio_tokens_chunks):
        interval_input_ids += text_chunk + [AUD_START_ID] + audio_chunk + [AUD_END_ID]
    return interval_input_ids
def _chunk_by_schedule(tokens, sizes):
    """Split ``tokens`` into chunks sized by ``sizes``, repeating the last size.

    Mirrors the original pop-from-front behavior without mutating ``sizes``:
    chunk i uses sizes[i] while i is in range, then sizes[-1] for the rest.
    """
    chunks = []
    pos = 0
    idx = 0
    while pos < len(tokens):
        size = sizes[min(idx, len(sizes) - 1)]
        chunks.append(tokens[pos : pos + size])
        pos += size
        idx += 1
    return chunks


def text_audio_interval(input_ids, AUD_START_ID, AUD_END_ID, text_audio_interval_ratio):
    """Interleave a single audio span with surrounding text at a fixed cadence.

    ``text_audio_interval_ratio`` is a flat [T1, A1, T2, A2, ...] list: pair i
    emits Ti text tokens then Ai audio tokens (wrapped in AUD_START/AUD_END,
    so each audio budget is reduced by 2); the last pair repeats for the
    remainder. Defaults to [1, 10, 4, 10] — the first text chunk is a single
    token so audio output starts early. Sequences with no audio span, or audio
    only, are returned unchanged. Exactly one audio span is supported.
    """
    if text_audio_interval_ratio is None:
        # T A T A (effective default; earlier experiments used [13, 26] and
        # [1, 4, 3, 8, 4, 10]).
        text_audio_interval_ratio = [1, 10, 4, 10]

    text_nums = text_audio_interval_ratio[::2]
    audio_nums = text_audio_interval_ratio[1::2]
    # Exclude the AUD_START / AUD_END markers from each audio chunk budget.
    audio_nums = [x - 2 for x in audio_nums]

    st = [i for i, x in enumerate(input_ids) if x == AUD_START_ID]
    ed = [i for i, x in enumerate(input_ids) if x == AUD_END_ID]

    # Text-only sequence: nothing to interleave.
    if len(st) == 0 and len(ed) == 0:
        return input_ids

    assert len(st) == 1
    assert len(ed) == 1
    st = st[0]
    ed = ed[0]
    assert st < ed

    # Audio-only sequence: nothing to interleave.
    if st == 0 and ed == len(input_ids) - 1:
        return input_ids

    audio_tokens = input_ids[st + 1 : ed]
    text_tokens = input_ids[:st] + input_ids[ed + 1 :]

    audio_tokens_chunks = _chunk_by_schedule(audio_tokens, audio_nums)
    text_tokens_chunks = _chunk_by_schedule(text_tokens, text_nums)

    # Merge surplus chunks of the longer stream into its final chunk so the
    # two streams pair up one-to-one.
    chunk_num = min(len(audio_tokens_chunks), len(text_tokens_chunks))
    audio_tokens_chunks = audio_tokens_chunks[: chunk_num - 1] + [
        sum(audio_tokens_chunks[chunk_num - 1 :], [])
    ]
    text_tokens_chunks = text_tokens_chunks[: chunk_num - 1] + [
        sum(text_tokens_chunks[chunk_num - 1 :], [])
    ]

    interval_input_ids = []
    for text_chunk, audio_chunk in zip(text_tokens_chunks, audio_tokens_chunks):
        interval_input_ids += text_chunk + [AUD_START_ID] + audio_chunk + [AUD_END_ID]
    return interval_input_ids