import gradio as gr
import torch
import os

from uni_moe.model.processing_qwen2_vl import Qwen2VLProcessor
from uni_moe.model.modeling_out import GrinQwen2VLOutForConditionalGeneration
from uni_moe.qwen_vl_utils import process_mm_info
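
# Model checkpoint to load and the device to run it on.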
MODEL_ID = "HIT-TMG/Uni-MoE-2.0-Omni"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model on {DEVICE}...")
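
# Load the processor and model once at startup; bfloat16 keeps memory usage manageable.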
processor = Qwen2VLProcessor.from_pretrained(MODEL_ID)
model = GrinQwen2VLOutForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16
).to(DEVICE)
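
# The Uni-MoE processing code expects the model config to be attached as data_args.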
processor.data_args = model.config


def generate_response(text_input, image_path, audio_path):
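    """Answer a text prompt about an optional image and/or audio clip."""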
    content = []
    prompt_text = text_input
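
    # Attach any uploaded media and prepend the matching placeholder tag to the prompt.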
    if audio_path:
        content.append({"type": "audio", "audio": audio_path})
        prompt_text = "<audio>\n" + prompt_text

    if image_path:
        content.append({"type": "image", "image": image_path})
        prompt_text = "<image>\n" + prompt_text

    content.append({"type": "text", "text": prompt_text})

    messages = [{
        "role": "user",
        "content": content
    }]

    texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
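
    # Expand the plain <image>/<audio>/<video> tags into the model's special multimodal tokens.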
    texts = texts.replace("<image>", "<|vision_start|><|image_pad|><|vision_end|>") \
                 .replace("<audio>", "<|audio_start|><|audio_pad|><|audio_end|>") \
                 .replace("<video>", "<|vision_start|><|video_pad|><|vision_end|>")
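
    # Gather the image/video/audio inputs referenced in the messages.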
    image_inputs, video_inputs, audio_inputs = process_mm_info(messages)

    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        audios=audio_inputs,
        padding=True,
        return_tensors="pt",
    )
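
    # Give input_ids a batch dimension before generation.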
    if "input_ids" in inputs:
        inputs["input_ids"] = inputs["input_ids"].unsqueeze(0)

    inputs = inputs.to(device=model.device)
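
    # Sample up to 2048 new tokens at temperature 0.7, without gradient tracking.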
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            use_cache=True,
            pad_token_id=processor.tokenizer.eos_token_id,
            max_new_tokens=2048,
            temperature=0.7,
            do_sample=True
        )
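
    # Decode only the newly generated tokens, skipping the echoed prompt.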
    response = processor.batch_decode(output_ids[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
    return response
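

# Minimal Gradio UI: a text prompt plus optional image and audio uploads.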
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Question/Prompt", placeholder="Describe the image or audio..."),
        gr.Image(type="filepath", label="Upload Image (Optional)"),
        gr.Audio(type="filepath", label="Upload Audio (Optional)")
    ],
    outputs=gr.Textbox(label="Uni-MoE Response"),
    title="Uni-MoE 2.0 Omni Demo",
    description="Upload an image or audio and ask questions about them using Uni-MoE 2.0."
)

if __name__ == "__main__":
    demo.launch()