import os
import torch
import numpy as np
from PIL import Image
import spaces
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info # 请确保该模块在你的环境可用
from transformers import HunYuanVLForConditionalGeneration
import gradio as gr
from argparse import ArgumentParser
import copy
import requests
from io import BytesIO
import tempfile
import hashlib
import gc
def _get_args():
    """Parse the demo's command-line options.

    Returns:
        The ``argparse`` namespace with ``checkpoint_path``, ``cpu_only``,
        ``flash_attn2``, ``share`` and ``inbrowser``.
    """
    cli = ArgumentParser()
    cli.add_argument(
        '-c',
        '--checkpoint-path',
        type=str,
        default='tencent/HunyuanOCR',
        help='Checkpoint name or path, default to %(default)r',
    )
    cli.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')

    # The remaining switches are plain on/off flags that default to off.
    boolean_switches = (
        ('--flash-attn2', 'Enable flash_attention_2 when loading the model.'),
        ('--share', 'Create a publicly shareable link for the interface.'),
        ('--inbrowser',
         'Automatically launch the interface in a new tab on the default browser.'),
    )
    for flag, description in boolean_switches:
        cli.add_argument(flag, action='store_true', default=False, help=description)

    return cli.parse_args()
def _load_model_processor(args):
    """Load the HunyuanVL model and its processor.

    Attention implementations are tried in the order ``flash_attention_2``,
    ``sdpa``, ``eager``; the first one that loads is kept. If the last
    candidate also fails, its exception propagates to the caller.

    Args:
        args: Parsed CLI namespace. ``checkpoint_path`` names the checkpoint.
            When ``cpu_only`` is set the model is placed on the CPU and
            ``flash_attention_2`` is skipped.

    Returns:
        A ``(model, processor)`` tuple.
    """
    cpu_only = getattr(args, 'cpu_only', False)
    device_map = "cpu" if cpu_only else "cuda"
    hf_token = os.environ.get('HF_TOKEN')

    # (attention implementation, message printed right before trying it)
    attempts = [
        ("flash_attention_2", "[INFO] 尝试使用 flash_attention_2"),
        ("sdpa", "[INFO] 降级使用 sdpa"),
        ("eager", "[INFO] 使用 eager (最慢)"),
    ]
    if cpu_only:
        # flash_attention_2 needs a GPU, so do not attempt it on CPU.
        attempts = attempts[1:]
    # NOTE: args.flash_attn2 is not consulted; on GPU flash_attention_2 is
    # always attempted first, exactly as before.

    final_index = len(attempts) - 1
    for index, (attn_impl, notice) in enumerate(attempts):
        print(notice)
        try:
            model = HunYuanVLForConditionalGeneration.from_pretrained(
                args.checkpoint_path,
                attn_implementation=attn_impl,
                torch_dtype=torch.bfloat16,
                device_map=device_map,
                token=hf_token,
            )
        except Exception as exc:
            # Nothing left to fall back to: surface the failure.
            if index == final_index:
                raise
            print(f"[WARNING] {attn_impl} 不可用: {exc}")
        else:
            break

    # Use the same token as for the model so gated/private checkpoints can
    # load their processor as well.
    processor = AutoProcessor.from_pretrained(
        args.checkpoint_path,
        use_fast=False,
        trust_remote_code=True,
        token=hf_token,
    )
    return model, processor
def _parse_text(text):
    """Parse text and handle special formatting before it is displayed.

    NOTE(review): the original body contained a single substitution whose
    source line was corrupted (an unterminated string literal that made the
    whole module unparsable). The pattern could not be recovered, so the text
    is currently passed through unchanged -- restore the intended replacement
    here once it is known.

    Args:
        text: Text to clean. Non-string values (``None``, or the tuple used
            for uploaded files in the chat history) are returned untouched.

    Returns:
        The cleaned text, or the input itself when it is not a string.
    """
    if not isinstance(text, str):
        return text
    return text
def _gc():
    """Run Python garbage collection and, when CUDA is available, empty its cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def _launch_demo(args, model, processor):
# Key fix: reduce `duration` and add debug output.
@spaces.GPU(duration=60)
def call_local_model(model, processor, messages):
    """Generate a reply for one conversation and return the decoded text.

    Args:
        model: Model exposing ``generate`` and ``device``.
        processor: Processor used for the chat template, tensor preparation
            and decoding; its ``tokenizer`` supplies the EOS/PAD token ids.
        messages: A single conversation (list of role/content dicts).

    Returns:
        The list produced by ``processor.batch_decode`` for the newly
        generated tokens (one entry, since the batch holds one conversation).
    """
    import time

    t_begin = time.time()
    print(f"[DEBUG] 开始推理,时间: {t_begin}")
    print(f"[DEBUG] Messages: {messages}")

    # Everything downstream works on batches, so wrap the single conversation.
    batch = [messages]

    # Render each conversation to prompt text with the processor's template.
    prompts = []
    for conversation in batch:
        prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        prompts.append(prompt)
    t_template = time.time()
    print(f"[DEBUG] 模板处理耗时: {t_template - t_begin:.2f}s")

    image_inputs, video_inputs = process_vision_info(batch)
    t_vision = time.time()
    print(f"[DEBUG] 视觉处理耗时: {t_vision - t_template:.2f}s")

    model_inputs = processor(
        text=prompts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    t_inputs = time.time()
    print(f"[DEBUG] 输入处理耗时: {t_inputs - t_vision:.2f}s")
    has_input_ids = 'input_ids' in model_inputs
    print(f"[DEBUG] Input shape: {model_inputs.input_ids.shape if has_input_ids else 'N/A'}")

    # max_new_tokens was lowered from 8192 to 512 to avoid runaway generation,
    # and the EOS/PAD ids are passed explicitly so generation can stop.
    # NOTE(review): earlier comments also promised timeout protection and an
    # early-stop condition, but no such arguments are passed here.
    generation_options = dict(
        max_new_tokens=512,
        repetition_penalty=1.03,
        do_sample=False,
        eos_token_id=processor.tokenizer.eos_token_id,
        pad_token_id=processor.tokenizer.pad_token_id,
        use_cache=True,
    )
    with torch.no_grad():
        generated_ids = model.generate(**model_inputs, **generation_options)
    t_generated = time.time()
    print(f"[DEBUG] 生成耗时: {t_generated - t_inputs:.2f}s")
    print(f"[DEBUG] Generated shape: {generated_ids.shape}")

    # Drop the prompt tokens so that only the newly generated part is decoded.
    prompt_ids = model_inputs.input_ids if has_input_ids else model_inputs.inputs  # fallback
    generated_ids_trimmed = []
    for prompt_row, output_row in zip(prompt_ids, generated_ids):
        generated_ids_trimmed.append(output_row[len(prompt_row):])
    print(f"[DEBUG] Trimmed tokens count: {[len(ids) for ids in generated_ids_trimmed]}")

    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    t_decoded = time.time()
    print(f"[DEBUG] 解码耗时: {t_decoded - t_generated:.2f}s")
    print(f"[DEBUG] 总耗时: {t_decoded - t_begin:.2f}s")
    print(f"[DEBUG] Output: {output_texts[0][:200]}...")  # only the first 200 characters
    return output_texts
def create_predict_fn():
    """Build the chat callback that answers the most recent user turn.

    The returned ``predict(_chatbot, task_history)`` is a generator. It always
    yields the chatbot list at least once; a plain ``return value`` inside a
    generator would be discarded and the caller would receive nothing.

    NOTE(review): ``_remove_image_special`` is called below but is not defined
    in the part of the file reviewed here -- confirm it exists.
    """

    def _build_messages(history):
        """Turn ``(query, answer)`` history pairs into chat messages."""
        messages = []
        content = []
        for q, a in history:
            if isinstance(q, (tuple, list)):
                # File entries are stored as a one-element tuple holding the
                # path; remote URLs pass through, local paths become absolute.
                img_path = q[0]
                if img_path.startswith(('http://', 'https://')):
                    content.append({'type': 'image', 'image': img_path})
                else:
                    content.append({'type': 'image', 'image': f'{os.path.abspath(img_path)}'})
            else:
                # A text entry closes the turn: images collected so far plus
                # this text form the user message, followed by the answer.
                content.append({'type': 'text', 'text': q})
                messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{'type': 'text', 'text': a}]})
                content = []
        # Drop the trailing assistant entry (the turn still to be answered).
        messages.pop()
        return messages

    def predict(_chatbot, task_history):
        # Nothing to answer: hand the chatbot back unchanged.
        if not _chatbot or not task_history:
            yield _chatbot
            return

        chat_query = _chatbot[-1][0]
        query = task_history[-1][0]
        if len(chat_query) == 0:
            # Empty submission: remove the placeholder turn and publish the
            # shortened chat so the display matches task_history.
            _chatbot.pop()
            task_history.pop()
            yield _chatbot
            return

        print('User: ', query)
        messages = _build_messages(copy.deepcopy(task_history))

        # Call the model to get the response.
        response_list = call_local_model(model, processor, messages)
        response = response_list[0] if response_list else ""

        full_response = _parse_text(response)
        _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(full_response))
        task_history[-1] = (query, full_response)
        print('HunyuanOCR: ' + _parse_text(full_response))
        yield _chatbot

    return predict
def create_regenerate_fn():
    """Build the chat callback that re-runs the last answered turn.

    The returned ``regenerate(_chatbot, task_history)`` is a generator. When
    there is nothing to regenerate it yields the chatbot unchanged (a bare
    ``return value`` inside a generator would be discarded). Otherwise it
    clears the last answer and delegates to ``predict``.
    """

    def regenerate(_chatbot, task_history):
        if not task_history:
            yield _chatbot
            return

        item = task_history[-1]
        if item[1] is None:
            # The last turn has no answer yet, so there is nothing to redo.
            yield _chatbot
            return

        # Forget the previous answer in both histories.
        task_history[-1] = (item[0], None)
        chatbot_item = _chatbot.pop(-1)
        if chatbot_item[0] is None:
            _chatbot[-1] = (_chatbot[-1][0], None)
        else:
            _chatbot.append((chatbot_item[0], None))

        # `predict` is looked up in the enclosing scope when this runs.
        yield from predict(_chatbot, task_history)

    return regenerate
# Instantiate the two chat callbacks. `regenerate` calls `predict` by name,
# so `predict` has to be bound in this scope.
predict = create_predict_fn()
regenerate = create_regenerate_fn()
def add_text(history, task_history, text):
    """Append a pending text turn to both histories.

    Args:
        history: Display history (``None`` is treated as empty); receives the
            text after ``_parse_text``.
        task_history: Model-side history (``None`` is treated as empty);
            receives the raw text.
        text: The user's message.

    Returns:
        ``(history, task_history, '')`` -- new lists with the turn appended,
        plus an empty string.
    """
    shown = [] if history is None else history
    recorded = [] if task_history is None else task_history
    shown_turn = (_parse_text(text), None)
    recorded_turn = (text, None)
    return shown + [shown_turn], recorded + [recorded_turn], ''
def add_file(history, task_history, file):
    """Append an uploaded file as a pending turn to both histories.

    Args:
        history: Display history (``None`` is treated as empty).
        task_history: Model-side history (``None`` is treated as empty).
        file: Object whose ``name`` attribute is the file path.

    Returns:
        ``(history, task_history)`` -- new lists, each ending with
        ``((file.name,), None)``; the input lists are not modified.
    """
    shown = [] if history is None else history
    recorded = [] if task_history is None else task_history
    # The path is wrapped in a one-element tuple to mark the entry as a file.
    shown = shown + [((file.name,), None)]
    recorded = recorded + [((file.name,), None)]
    return shown, recorded
def download_url_image(url):
    """Download the image at *url* to a local temporary file.

    The file name is derived from the MD5 of the URL, so a URL that was
    already fetched is served from disk. The payload is first written to a
    scratch file and then renamed into place, so an interrupted download is
    never mistaken for a cached image on a later call.

    Args:
        url: Image URL to fetch.

    Returns:
        Path of the local copy, or *url* itself when anything fails (the
        failure is printed, not raised).
    """
    scratch_path = None
    try:
        # The URL hash is only a cache key; it is not used for security.
        url_hash = hashlib.md5(url.encode()).hexdigest()
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"hyocr_demo_{url_hash}.png")

        # Reuse an earlier download; an empty file counts as a failed one.
        if os.path.isfile(temp_path) and os.path.getsize(temp_path) > 0:
            return temp_path

        response = requests.get(url, timeout=10)
        response.raise_for_status()

        # Write next to the final location, then publish it atomically.
        fd, scratch_path = tempfile.mkstemp(
            prefix="hyocr_demo_", suffix=".part", dir=temp_dir
        )
        with os.fdopen(fd, 'wb') as f:
            f.write(response.content)
        os.replace(scratch_path, temp_path)
        scratch_path = None
        return temp_path
    except Exception as e:
        # Best effort by design: return the original URL on failure.
        print(f"下载图片失败: {url}, 错误: {e}")
        return url
    finally:
        # Remove a scratch file left behind by a failed write or rename.
        if scratch_path is not None:
            try:
                os.remove(scratch_path)
            except OSError:
                pass
def reset_user_input():
    """Return a Gradio update that sets the target component's value to ''."""
    cleared = gr.update(value='')
    return cleared
def reset_state(_chatbot, task_history):
    """Clear both histories in place, run ``_gc()``, and return a new empty list."""
    for conversation in (task_history, _chatbot):
        conversation.clear()
    _gc()
    return []
# Example image path configuration, keyed by task name - replace with actual image paths.
# NOTE(review): these URLs carry signed query parameters (q-sign-time, q-signature),
# so they presumably stop working once the signature expires - confirm before relying on them.
EXAMPLE_IMAGES = {
"spotting": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/23cc43af9376b948f3febaf4ce854a8a.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523817%3B1794627877&q-key-time=1763523817%3B1794627877&q-header-list=host&q-url-param-list=&q-signature=8ebd6a9d3ed7eba73bb783c337349db9c29972e2", # TODO: replace with the scene-text example image path
"parsing": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/c4997ebd1be9f7c3e002fabba8b46cb7.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=d2cd12be4c7902821c8c82203e4642624046911a",
"ie": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/7c67c0f78e4423d51644a325da1f8e85.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=803648f3253706f654faf1423869fd9e00e7056e",
"vqa": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/fea0865d1c70c53aaa2ab91cd0e787f5.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763523818%3B1794627878&q-key-time=1763523818%3B1794627878&q-header-list=host&q-url-param-list=&q-signature=a92b94e298a11aea130d730d3b16ee761acc3f4c",
"translation": "https://hunyuan-multimodal-1258344703.cos.ap-guangzhou.myqcloud.com/hunyuan_multimodal/mllm_data/d1af99d35e9db9e820ebebb5bc68993a.jpg?q-sign-algorithm=sha1&q-ak=AKIDbLEFMUYZgyERZnygUQLC7xkQ1hTAzulX&q-sign-time=1763967603%3B1795071663&q-key-time=1763967603%3B1795071663&q-header-list=host&q-url-param-list=&q-signature=a57080c0b3d4c76ea74b88c6291f9004241c9d49",
# Local-file alternatives for the same keys:
# "spotting": "examples/spotting.jpg",
# "parsing": "examples/parsing.jpg",
# "ie": "examples/ie.jpg",
# "vqa": "examples/vqa.jpg",
# "translation": "examples/translation.jpg"
}
with gr.Blocks(css="""
body {
background: #f5f7fa;
}
.gradio-container {
max-width: 100% !important;
padding: 0 40px !important;
}
.header-section {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 30px 0;
margin: -20px -40px 30px -40px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.header-content {
max-width: 1600px;
margin: 0 auto;
padding: 0 40px;
display: flex;
align-items: center;
gap: 20px;
}
.header-logo {
height: 60px;
}
.header-text h1 {
color: white;
font-size: 32px;
font-weight: bold;
margin: 0 0 5px 0;
}
.header-text p {
color: rgba(255,255,255,0.9);
margin: 0;
font-size: 14px;
}
.main-container {
max-width: 1800px;
margin: 0 auto;
}
.chatbot {
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08) !important;
border-radius: 12px !important;
border: 1px solid #e5e7eb !important;
background: white !important;
}
.input-panel {
background: white;
padding: 20px;
border-radius: 12px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border: 1px solid #e5e7eb;
}
.input-box textarea {
border: 2px solid #e5e7eb !important;
border-radius: 8px !important;
font-size: 14px !important;
}
.input-box textarea:focus {
border-color: #667eea !important;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
color: white !important;
font-weight: 500 !important;
padding: 10px 24px !important;
font-size: 14px !important;
}
.btn-primary:hover {
transform: translateY(-1px) !important;
box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}
.btn-secondary {
background: white !important;
border: 2px solid #667eea !important;
color: #667eea !important;
padding: 8px 20px !important;
font-size: 14px !important;
}
.btn-secondary:hover {
background: #f0f4ff !important;
}
.example-grid {
display: grid;
grid-template-columns: repeat(4, 1fr);
gap: 20px;
margin-top: 30px;
}
.example-card {
background: white;
border-radius: 12px;
overflow: hidden;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border: 1px solid #e5e7eb;
transition: all 0.3s ease;
}
.example-card:hover {
transform: translateY(-4px);
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.15);
border-color: #667eea;
}
.example-image-wrapper {
width: 100%;
height: 180px;
overflow: hidden;
background: #f5f7fa;
}
.example-image-wrapper img {
width: 100%;
height: 100%;
object-fit: cover;
}
.example-btn {
width: 100% !important;
white-space: pre-wrap !important;
text-align: left !important;
padding: 16px !important;
background: white !important;
border: none !important;
border-top: 1px solid #e5e7eb !important;
color: #1f2937 !important;
font-size: 14px !important;
line-height: 1.6 !important;
transition: all 0.3s ease !important;
font-weight: 500 !important;
}
.example-btn:hover {
background: #f9fafb !important;
color: #667eea !important;
}
.feature-section {
background: white;
padding: 24px;
border-radius: 12px;
margin-top: 30px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border: 1px solid #e5e7eb;
}
.section-title {
font-size: 18px;
font-weight: 600;
color: #1f2937;
margin-bottom: 20px;
padding-bottom: 12px;
border-bottom: 2px solid #e5e7eb;
}
""") as demo:
# 顶部导航栏
gr.HTML("""
Powered by Tencent Hunyuan Team
© 2025 Tencent Hunyuan Team. All rights reserved.
本系统基于 HunyuanOCR 构建 | 仅供学习研究使用