import os
import warnings

import gradio as gr
import spaces  # Hugging Face Spaces SDK; provides the @spaces.GPU decorator
import torch
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

# Silence noisy warnings emitted while loading remote-code models.
warnings.filterwarnings("ignore")


MODEL_ID = "openbmb/MiniCPM-o-2_6"

# Loaded lazily on first request (see load_model) so the Space starts fast
# and the weights land on the GPU allocated for that request.
model = None
tokenizer = None


def load_model():
    """Load the model and tokenizer on first use only."""
    global model, tokenizer

    if model is not None:
        return

    print(f"Loading {MODEL_ID}...")

    # fp16 on GPU keeps the ~8B model within a single device; fall back to
    # fp32 on CPU, where half precision is poorly supported.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            use_fast=False,
        )

        model = AutoModel.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            attn_implementation="eager",
        ).eval()

        if torch.cuda.is_available():
            model = model.cuda()

        print("Model loaded successfully!")

    except Exception as e:
        print(f"Error with AutoModel, trying AutoModelForCausalLM: {e}")

        try:
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                attn_implementation="eager",
            ).eval()

            if torch.cuda.is_available():
                model = model.cuda()

            print("Model loaded successfully with AutoModelForCausalLM!")

        except Exception as e2:
            print(f"Failed to load model: {e2}")
            raise RuntimeError(f"Could not load model: {e2}") from e2
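

# Loading notes (conservative defaults, not hard requirements):
# - trust_remote_code=True is needed because MiniCPM-o ships custom modeling
#   code on the Hub rather than a native transformers architecture.
# - attn_implementation="eager" avoids a dependency on flash-attn/SDPA
#   kernels; environments that provide them can switch for speed.
# - The AutoModelForCausalLM fallback covers transformers versions whose
#   AutoModel mapping fails to resolve this remote architecture.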


def process_image(image_input):
    """Normalize the incoming image for the model."""
    if image_input is None:
        return None

    # Accept either a file path or a PIL image; convert('RGB') flattens
    # alpha channels and palette images into the 3-channel format the
    # vision encoder expects.
    if isinstance(image_input, str):
        return Image.open(image_input).convert('RGB')
    else:
        return image_input.convert('RGB')
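

# ZeroGPU note: @spaces.GPU attaches a GPU only for the duration of the
# decorated call (capped at 60 s here), which is why the model is loaded
# inside the request rather than at import time.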
@spaces.GPU(duration=60)
def generate_response(
    text_input,
    image_input,
    temperature,
    top_p,
    max_new_tokens,
):
    """Process text and image inputs with MiniCPM-o-2_6."""
    if not text_input and not image_input:
        return "Please provide text or image input."

    try:
        load_model()

        if image_input is not None:
            image = process_image(image_input)

            if not text_input:
                text_input = "What is shown in this image? Please describe in detail."

            # Preferred path: MiniCPM's custom chat() API from the remote
            # code handles the interleaved image+text message format itself.
            if hasattr(model, 'chat'):
                try:
                    msgs = [{"role": "user", "content": [image, text_input]}]

                    with torch.no_grad():
                        response = model.chat(
                            image=image,
                            msgs=msgs,
                            tokenizer=tokenizer,
                            sampling=True,
                            temperature=temperature,
                            top_p=top_p,
                            max_new_tokens=max_new_tokens,
                        )

                    return response

                except Exception as e:
                    print(f"Chat method failed: {e}")

            # Fallback: plain tokenizer()/generate() cannot consume the image,
            # so this degrades to a text-only answer with a placeholder where
            # the image would go.
            prompt = f"Image: [Image will be processed]\n\nQuestion: {text_input}\n\nAnswer:"
        else:
            prompt = text_input

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        )

        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        # Pass sampling parameters only when sampling; greedy decoding
        # (temperature == 0) ignores them and transformers warns otherwise.
        if temperature > 0:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True,
        )

        return response.strip()

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {str(e)}"
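

# Quick smoke test outside the UI (hypothetical prompt and settings):
#   print(generate_response("Say hello in three languages.", None, 0.7, 0.9, 128))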


def clear_all():
    """Clear all inputs and outputs."""
    return "", None, ""


def update_examples_visibility(show_examples):
    """Toggle visibility of the examples block."""
    return gr.update(visible=show_examples)
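

# Note: update_examples_visibility is not wired to any control in the UI
# below; it appears intended for an optional "show examples" toggle.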


def create_demo():
    """Build the simple Gradio interface."""
    with gr.Blocks(title="MiniCPM-o-2.6", css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        h1 {
            text-align: center;
        }
        .contain {
            background: white;
            border-radius: 10px;
            padding: 20px;
        }
    """) as demo:

        gr.Markdown(
            """
            # 🤖 MiniCPM-o-2.6 - Multimodal AI Assistant

            <div style="text-align: center;">
            <p>
            <b>8B-parameter model</b> with GPT-4o-level multimodal performance (per the model card)<br>
            Supports text generation, image understanding, OCR, and multilingual conversation
            </p>
            </div>
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    text_input = gr.Textbox(
                        label="💭 Text Input",
                        placeholder="Enter your question or prompt here...\nYou can ask about images, request text generation, or have a conversation.",
                        lines=4,
                        elem_id="text_input"
                    )

                    image_input = gr.Image(
                        label="📷 Image Input (Optional)",
                        type="pil",
                        elem_id="image_input"
                    )

                    with gr.Row():
                        submit_btn = gr.Button(
                            "🚀 Generate Response",
                            variant="primary",
                            scale=2
                        )
                        clear_btn = gr.Button(
                            "🗑️ Clear All",
                            variant="secondary",
                            scale=1
                        )

                output = gr.Textbox(
                    label="🤖 AI Response",
                    lines=10,
                    interactive=False,
                    elem_id="output"
                )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### ⚙️ Generation Settings")

                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.0,
                        maximum=1.5,
                        value=0.7,
                        step=0.1,
                        info="Controls randomness (0 = deterministic, 1.5 = very creative)"
                    )

                    top_p = gr.Slider(
                        label="Top-p (Nucleus Sampling)",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.05,
                        info="Controls diversity of output"
                    )

                    max_new_tokens = gr.Slider(
                        label="Max New Tokens",
                        minimum=50,
                        maximum=2048,
                        value=512,
                        step=50,
                        info="Maximum length of the generated response"
                    )

                gr.Markdown(
                    """
                    ### 📚 Quick Tips:

                    **Text Generation:**
                    - Ask questions
                    - Request explanations
                    - Generate creative content

                    **Image Understanding:**
                    - Upload an image
                    - Ask about its contents
                    - Request OCR / text extraction
                    - Get detailed descriptions

                    **Languages:**
                    - English, Chinese, Arabic
                    - And many more!
                    """
                )

        with gr.Group():
            gr.Markdown("### 💡 Example Prompts")
            gr.Examples(
                examples=[
                    ["Explain quantum computing in simple terms for a beginner.", None],
                    ["Write a short story about a robot learning to paint.", None],
                    ["What are the main differences between Python and JavaScript?", None],
                    ["Create a healthy meal plan for one week.", None],
                    ["Translate 'Hello, how are you?' to French, Spanish, and Arabic.", None],
                ],
                inputs=[text_input, image_input],
                outputs=output,
                # Examples run with moderate fixed settings, since gr.Examples
                # cannot read the live slider values.
                fn=lambda t, i: generate_response(t, i, 0.7, 0.9, 512),
                cache_examples=False,
                label="Click any example to try it"
            )
        # Wire up events: button click and textbox submit both generate;
        # the clear button resets all three components.
        submit_btn.click(
            fn=generate_response,
            inputs=[text_input, image_input, temperature, top_p, max_new_tokens],
            outputs=output,
            api_name="generate"
        )

        text_input.submit(
            fn=generate_response,
            inputs=[text_input, image_input, temperature, top_p, max_new_tokens],
            outputs=output
        )

        clear_btn.click(
            fn=clear_all,
            inputs=[],
            outputs=[text_input, image_input, output]
        )

        demo.load(
            lambda: gr.Info("Model is loading... This may take a moment on first use."),
            inputs=None,
            outputs=None
        )

    return demo
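

# Launch notes: ssr_mode=False is a common workaround for server-side
# rendering issues on some Gradio versions, and share=False is effectively
# a no-op when the app runs on Hugging Face Spaces.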
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        ssr_mode=False,
        show_error=True,
        share=False
    )