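# Gradio demo for VITA-1.5 (VITA-MLLM/VITA-1.5): multimodal chat that accepts text,
# image, video, and audio input and streams a synthesized speech reply alongside the text.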
import torch
import os
import argparse
import numpy as np
import copy
import gradio as gr
import re
import torchaudio
import io
import cv2
import time
import math
from numba import jit
import spaces
from huggingface_hub import snapshot_download
from vita.constants import (
    DEFAULT_AUDIO_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
    MAX_IMAGE_LENGTH,
    MIN_IMAGE_LENGTH,
)
from vita.conversation import conv_templates, SeparatorStyle
from vita.model.builder import load_pretrained_model
from vita.util.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    tokenizer_image_token,
    tokenizer_image_audio_token,
)
from vita.util.utils import disable_torch_init
from PIL import Image
from decord import VideoReader, cpu
from vita.model.vita_tts.decoder.llm2tts import llm2TTS
from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
from vita.util.data_utils_video_audio_neg_patch import dynamic_preprocess
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoFeatureExtractor

# TTS decoding parameters
decoder_topk = 2
codec_chunk_size = 40
codec_padding_size = 10

# Trailing punctuation (full-width and half-width) stripped from text queries before they
# are added to the task history.
PUNCTUATION = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def float_to_int16(audio: np.ndarray) -> np.ndarray:
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)

def remove_special_characters(input_str):
    # Remove special tokens
    special_tokens = ['☞', '☟', '☜', '<unk>', '<|im_end|>']
    for token in special_tokens:
        input_str = input_str.replace(token, '')
    return input_str
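
# Rewrite math notation (operators, functions, Greek letters) into spoken words so the TTS
# module can read formulas aloud; operators map to Chinese words (e.g. "+" -> "加").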
def replace_equation(sentence):
    special_notations = {
        "sin": " sine ",
        "cos": " cosine ",
        "tan": " tangent ",
        "cot": " cotangent ",
        "sec": " secant ",
        "csc": " cosecant ",
        "log": " logarithm ",
        "exp": "e^",
        "sqrt": "根号 ",
        "abs": "绝对值 ",
    }
    special_operators = {
        "+": "加",
        "-": "减",
        "*": "乘",
        "/": "除",
        "=": "等于",
        '!=': '不等于',
        '>': '大于',
        '<': '小于',
        '>=': '大于等于',
        '<=': '小于等于',
    }
    greek_letters = {
        "α": "alpha ",
        "β": "beta ",
        "γ": "gamma ",
        "δ": "delta ",
        "ε": "epsilon ",
        "ζ": "zeta ",
        "η": "eta ",
        "θ": "theta ",
        "ι": "iota ",
        "κ": "kappa ",
        "λ": "lambda ",
        "μ": "mu ",
        "ν": "nu ",
        "ξ": "xi ",
        "ο": "omicron ",
        "π": "派 ",
        "ρ": "rho ",
        "σ": "sigma ",
        "τ": "tau ",
        "υ": "upsilon ",
        "φ": "phi ",
        "χ": "chi ",
        "ψ": "psi ",
        "ω": "omega ",
    }

    sentence = sentence.replace('**', ' ')
    sentence = re.sub(r'(?<![\d)])-(\d+)', r'负\1', sentence)
    for key in special_notations:
        sentence = sentence.replace(key, special_notations[key])
    for key in special_operators:
        sentence = sentence.replace(key, special_operators[key])
    for key in greek_letters:
        sentence = sentence.replace(key, greek_letters[key])

    sentence = re.sub(r'\(?(\d+)\)?\((\d+)\)', r'\1乘\2', sentence)
    sentence = re.sub(r'\(?(\w+)\)?\^\(?(\w+)\)?', r'\1的\2次方', sentence)
    return sentence

def is_video(file_path):
    video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in video_extensions


def is_image(file_path):
    image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in image_extensions


def is_wav(file_path):
    wav_extensions = {'.wav'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in wav_extensions
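
# Only the LLM's input embedding table is kept at runtime (the TTS decoder consumes token
# embeddings rather than token ids), so the full model is loaded once and then discarded.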
def load_model_embedding(model_path):
    config_path = os.path.join(model_path, 'config.json')
    config = VITAQwen2Config.from_pretrained(config_path)
    model = VITAQwen2ForCausalLM.from_pretrained(model_path, config=config, low_cpu_mem_usage=True)
    embedding = model.get_input_embeddings()
    del model
    return embedding

def split_into_sentences(text):
    sentence_endings = re.compile(r'[,。?\n!?、,?.!]')
    sentences = sentence_endings.split(text)
    return [sentence.strip() for sentence in sentences if sentence.strip()]

def convert_webm_to_mp4(input_file, output_file):
    try:
        cap = cv2.VideoCapture(input_file)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_file, fourcc, 20.0, (int(cap.get(3)), int(cap.get(4))))
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            out.write(frame)
        cap.release()
        out.release()
    except Exception as e:
        print(f"Error: {e}")
        raise
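
# Decode a video with decord and sample frames uniformly at ~video_framerate fps, clamped
# between min_frames and max_frames; frames are optionally padded to square and preprocessed
# by the vision tower's image processor.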
def _get_rawvideo_dec(
    video_path,
    image_processor=None,
    max_frames=MAX_IMAGE_LENGTH,
    min_frames=MIN_IMAGE_LENGTH,
    image_resolution=384,
    video_framerate=1,
    s=None,
    e=None,
    image_aspect_ratio="pad",
):
    # speed up video decode via decord.
    if s is None:
        start_time, end_time = None, None
    else:
        start_time = int(s)
        end_time = int(e)
        start_time = start_time if start_time >= 0.0 else 0.0
        end_time = end_time if end_time >= 0.0 else 0.0
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        raise FileNotFoundError

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        # T x 3 x H x W
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))
        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [
                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)
            ]
        elif len(all_pos) < min_frames:
            sample_pos = [
                all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)
            ]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]

        if image_aspect_ratio == "pad":

            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            patch_images = [
                expand2square(i, tuple(int(x * 255) for x in image_processor.image_mean))
                for i in patch_images
            ]
            patch_images = [
                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                for i in patch_images
            ]
        else:
            patch_images = [
                image_processor.preprocess(i, return_tensors="pt")["pixel_values"][0]
                for i in patch_images
            ]

        patch_images = torch.stack(patch_images)
        slice_len = patch_images.shape[0]
        return patch_images, slice_len
    else:
        print(f"video path: {video_path} error.")
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0 and count % 2 == 1:
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    return "".join(lines)
MODEL_NAME = "VITA-MLLM/VITA-1.5"
model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")

model_type = "qwen2p5_instruct"
tokenizer, model, feature_extractor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
image_processor = vision_tower.image_processor

audio_encoder = model.get_audio_encoder()
audio_encoder.to(dtype=torch.float16)
audio_processor = audio_encoder.audio_processor

model.eval()

tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
llm_embedding = load_model_embedding(model_path).to(device)
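
# Build the multimodal prompt from the accumulated task_history (expanding image, video,
# and audio placeholders), run model.generate, and yield the updated chatbot history.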
def predict(_chatbot, task_history):
    chat_query = task_history[-1][0]
    print(task_history)

    conv_mode = "qwen2p5_instruct"
    conv = conv_templates[conv_mode].copy()

    all_audio_path = []
    all_visual_tensor = []

    qs = ''
    input_mode = 'lang'
    for i, (q, a) in enumerate(task_history):
        if isinstance(q, (tuple, list)):
            if is_image(q[0]):
                image = Image.open(q[0]).convert("RGB")
                image, p_num = dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True)
                assert len(p_num) == 1
                image_tensor = model.process_images(image, model.config).to(
                    dtype=model.dtype, device="cuda"
                )
                all_visual_tensor.append(image_tensor)
                input_mode = 'image'
                qs += DEFAULT_IMAGE_TOKEN * p_num[0] + '\n'
            elif is_video(q[0]):
                video_frames, slice_len = _get_rawvideo_dec(
                    q[0],
                    image_processor,
                    max_frames=MAX_IMAGE_LENGTH,
                    video_framerate=1,
                    image_aspect_ratio=getattr(model.config, "image_aspect_ratio", None),
                )
                image_tensor = video_frames.half().cuda()
                all_visual_tensor.append(image_tensor)
                input_mode = 'video'
                qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
            elif is_wav(q[0]):
                if a is not None and a.startswith('☜'):
                    continue
                else:
                    all_audio_path.append(q[0])
                    new_q = qs + DEFAULT_AUDIO_TOKEN
                    qs = ''
                    conv.append_message(conv.roles[0], new_q)
                    conv.append_message(conv.roles[1], a)
        else:
            new_q = qs + q
            qs = ''
            conv.append_message(conv.roles[0], new_q)
            conv.append_message(conv.roles[1], a)

    if qs:
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt(input_mode)

    if all_audio_path:
        # Process multiple audio clips and merge them.
        all_audio_features = []
        all_audio_lengths = []
        all_audio_for_llm_lens = []
        for audio_path in all_audio_path:
            audio, audio_for_llm_lens = audio_processor.process(os.path.join(audio_path))
            all_audio_features.append(audio)
            all_audio_lengths.append(audio.shape[0])
            all_audio_for_llm_lens.append(audio_for_llm_lens)

        # Merge audio features.
        combined_audio = torch.cat(all_audio_features, dim=0)
        combined_audio = torch.unsqueeze(combined_audio, dim=0)

        # Merge length information.
        combined_length = torch.tensor(sum(all_audio_lengths))
        combined_length = torch.unsqueeze(combined_length, dim=0)

        # Merge the lengths as seen by the LLM.
        combined_for_llm_lens = torch.tensor(sum(all_audio_for_llm_lens))
        combined_for_llm_lens = torch.unsqueeze(combined_for_llm_lens, dim=0)

        audios = dict()
        audios["audios"] = combined_audio.half().cuda()
        audios["lengths"] = combined_length.half().cuda()
        audios["lengths_for_llm"] = combined_for_llm_lens.cuda()

        input_ids = (
            tokenizer_image_audio_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )
    else:
        # No audio input: feed a dummy (all-zero) audio tensor.
        audio = torch.zeros(400, 80)
        audio_length = audio.shape[0]
        audio_for_llm_lens = 60
        audio = torch.unsqueeze(audio, dim=0)
        audio_length = torch.unsqueeze(torch.tensor(audio_length), dim=0)
        audio_for_llm_lens = torch.unsqueeze(torch.tensor(audio_for_llm_lens), dim=0)

        audios = dict()
        audios["audios"] = audio.half().cuda()
        audios["lengths"] = audio_length.half().cuda()
        audios["lengths_for_llm"] = audio_for_llm_lens.cuda()

        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )

    if len(all_visual_tensor) > 0:
        all_visual_tensor = torch.cat(all_visual_tensor, dim=0)
    else:
        all_visual_tensor = torch.zeros((1, 3, 448, 448)).to(dtype=model.dtype, device="cuda")

    if type(all_visual_tensor) is list:
        print("all_visual_tensor is a list: ", len(all_visual_tensor))
    if type(all_visual_tensor) is torch.Tensor:
        print("all_visual_tensor is a tensor: ", all_visual_tensor.shape)

    # Stopping criteria.
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # Generate text.
    start_time = time.time()
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=all_visual_tensor,
            audios=audios,
            do_sample=False,
            temperature=0.01,
            top_p=None,
            num_beams=1,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            shared_v_pid_stride=None,
        )
    infer_time = time.time() - start_time

    output_ids = output_ids.sequences
    input_token_len = input_ids.shape[1]
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()
    print(f"Generated output: {outputs}")
    print(f"Time consumed: {infer_time}")

    task_history[-1] = (chat_query, outputs)
    remove_special_characters_output = remove_special_characters(outputs)
    _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
    print("query", chat_query)
    print("task_history", task_history)
    print(_chatbot)
    print("answer: ", outputs)
    yield _chatbot

def add_text(history, task_history, text):
    task_text = text
    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
        task_text = text[:-1]
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(task_text, None)]
    return history, task_history, ""


def add_file(history, task_history, file):
    history = history + [((file.name,), None)]
    task_history = task_history + [((file.name,), None)]
    return history, task_history


def add_audio(history, task_history, file):
    print(file)
    if file is None:
        return history, task_history
    history = history + [((file,), None)]
    task_history = task_history + [((file,), None)]
    return history, task_history


def add_video(history, task_history, file):
    print(file)
    if file is None:
        return history, task_history
    new_file_name = file.replace(".webm", ".mp4")
    if file.endswith(".webm"):
        convert_webm_to_mp4(file, new_file_name)
    history = history + [((new_file_name,), None)]
    task_history = task_history + [((new_file_name,), None)]
    print("add_video", history, task_history)
    return history, task_history


def reset_user_input():
    return gr.update(value="")


def reset_state(task_history):
    task_history.clear()
    return []
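
# Synthesize the latest reply sentence by sentence and stream 24 kHz PCM chunks to the
# output audio component.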
def stream_audio_output(history, task_history):
    print("stream_audio_output", history, task_history)
    text = history[-1][-1]
    text = text.replace("<br>", "")
    print("text", text)
    if not text:
        yield None, None
        return
    llm_response = replace_equation(remove_special_characters(text))
    for idx, text in enumerate(split_into_sentences(llm_response)):
        embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda())
        for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
                           None,
                           codec_chunk_size, codec_padding_size):
            if idx == 0:
                try:
                    split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0]
                    seg = seg[:, :, split_idx:]
                except Exception:
                    print('Do not need to split')
            if seg is not None and len(seg) > 0:
                seg = seg.to(torch.float32).cpu().numpy()
                yield 24000, float_to_int16(seg).T
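
# Gradio UI: chat window plus text, microphone/upload audio, webcam video, and file-upload
# inputs; submitting text or audio runs predict and then streams the TTS reply.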
with gr.Blocks(title="VideoMLLM") as demo:
    gr.Markdown("""<center><font size=8>VITA</center>""")
    chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
    query = gr.Textbox(lines=2, label='Text Input')
    task_history = gr.State([])

    with gr.Row():
        add_text_button = gr.Button("Submit Text (提交文本)")
        add_audio_button = gr.Button("Submit Audio (提交音频)")

    with gr.Row():
        with gr.Column(scale=2):
            addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
            video_input = gr.Video(sources=["webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)")
        with gr.Column(scale=1):
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            record_btn = gr.Audio(sources=["microphone", "upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
            audio_output = gr.Audio(
                label="Output Audio",
                value=None,
                format="wav",
                autoplay=True,
                streaming=True,
                interactive=False,
                show_label=True,
                waveform_options=gr.WaveformOptions(
                    sample_rate=24000,
                ),
            )

    add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
        reset_user_input, [], [query]
    ).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output, [chatbot, task_history], [audio_output],
    )

    video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

    add_audio_button.click(add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output, [chatbot, task_history], [audio_output],
    )

demo.launch()