Der11 / app.py
Derr11's picture
Update app.py
02a5a58 verified
import gradio as gr
import torch
import os
# استيراد المكتبات الخاصة بالنموذج (تأكد أن مجلد uni_moe موجود بجانب هذا الملف)
from uni_moe.model.processing_qwen2_vl import Qwen2VLProcessor
from uni_moe.model.modeling_out import GrinQwen2VLOutForConditionalGeneration
from uni_moe.qwen_vl_utils import process_mm_info
# إعداد النموذج
MODEL_ID = "HIT-TMG/Uni-MoE-2.0-Omni"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model on {DEVICE}...")
# تحميل المعالج والنموذج
processor = Qwen2VLProcessor.from_pretrained(MODEL_ID)
model = GrinQwen2VLOutForConditionalGeneration.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16
).to(DEVICE)
processor.data_args = model.config
def generate_response(text_input, image_path, audio_path):
# تجهيز محتوى الرسالة
content = []
# إضافة النص مع التاجات الخاصة إذا وجدت وسائط
prompt_text = text_input
if audio_path:
content.append({"type": "audio", "audio": audio_path})
prompt_text = "<audio>\n" + prompt_text
if image_path:
content.append({"type": "image", "image": image_path})
prompt_text = "<image>\n" + prompt_text
content.append({"type": "text", "text": prompt_text})
messages = [{
"role": "user",
"content": content
}]
# معالجة القوالب (Chat Template)
texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# استبدال التاجات الخاصة كما في المثال الأصلي
texts = texts.replace("<image>","<|vision_start|><|image_pad|><|vision_end|>") \
.replace("<audio>","<|audio_start|><|audio_pad|><|audio_end|>") \
.replace("<video>","<|vision_start|><|video_pad|><|vision_end|>")
# معالجة الوسائط
image_inputs, video_inputs, audio_inputs = process_mm_info(messages)
# تجهيز المدخلات للنموذج
inputs = processor(
text=texts,
images=image_inputs,
videos=video_inputs,
audios=audio_inputs,
padding=True,
return_tensors="pt",
)
# إضافة بعد جديد للـ inputs ونقلها للـ GPU
if "input_ids" in inputs:
inputs["input_ids"] = inputs["input_ids"].unsqueeze(0) # Unsqueeze كما في المثال
inputs = inputs.to(device=model.device)
# التوليد
with torch.no_grad():
output_ids = model.generate(
**inputs,
use_cache=True,
pad_token_id=processor.tokenizer.eos_token_id,
max_new_tokens=2048, # تم التقليل قليلاً لتسريع الاستجابة في الويب
temperature=0.7,
do_sample=True
)
# فك التشفير واستخراج النص فقط
response = processor.batch_decode(output_ids[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
return response
# بناء واجهة Gradio
with gr.Interface(
fn=generate_response,
inputs=[
gr.Textbox(label="Question/Prompt", placeholder="Describe the image or audio..."),
gr.Image(type="filepath", label="Upload Image (Optional)"),
gr.Audio(type="filepath", label="Upload Audio (Optional)")
],
outputs=gr.Textbox(label="Uni-MoE Response"),
title="Uni-MoE 2.0 Omni Demo",
description="Upload an image or audio and ask questions about them using Uni-MoE 2.0."
) as demo:
demo.launch()