Spaces:
Running
Running
使用CPU
Browse files
app.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import
|
|
|
|
| 3 |
import torch
|
| 4 |
-
from gradio_pdf import PDF
|
| 5 |
-
from pdf2image import convert_from_path
|
| 6 |
from PIL import Image
|
| 7 |
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
|
| 8 |
|
|
@@ -13,8 +12,7 @@ print("Loading Nanonets OCR model...")
|
|
| 13 |
model = AutoModelForImageTextToText.from_pretrained(
|
| 14 |
model_path,
|
| 15 |
torch_dtype="auto",
|
| 16 |
-
device_map="
|
| 17 |
-
attn_implementation="flash_attention_2",
|
| 18 |
)
|
| 19 |
model.eval()
|
| 20 |
|
|
@@ -23,7 +21,6 @@ processor = AutoProcessor.from_pretrained(model_path)
|
|
| 23 |
print("Model loaded successfully!")
|
| 24 |
|
| 25 |
|
| 26 |
-
@spaces.GPU()
|
| 27 |
def ocr_image_gradio(image, max_tokens=4096):
|
| 28 |
"""Process image through Nanonets OCR model for Gradio interface"""
|
| 29 |
if image is None:
|
|
@@ -70,30 +67,24 @@ def ocr_image_gradio(image, max_tokens=4096):
|
|
| 70 |
return output_text[0]
|
| 71 |
|
| 72 |
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
)
|
| 91 |
-
page_text = ocr_image_gradio(image, max_tokens)
|
| 92 |
-
all_text.append(f"--- PAGE {i + 1} ---\n{page_text}\n")
|
| 93 |
-
|
| 94 |
-
# Combine results
|
| 95 |
-
combined_text = "\n".join(all_text)
|
| 96 |
-
return combined_text
|
| 97 |
|
| 98 |
|
| 99 |
# Create Gradio interface
|
|
@@ -125,51 +116,55 @@ with gr.Blocks(title="Nanonets OCR Demo") as demo:
|
|
| 125 |
with gr.Row():
|
| 126 |
with gr.Column(scale=1):
|
| 127 |
image_input = gr.Image(
|
| 128 |
-
label="
|
| 129 |
)
|
| 130 |
image_max_tokens = gr.Slider(
|
| 131 |
minimum=1024,
|
| 132 |
maximum=8192,
|
| 133 |
value=4096,
|
| 134 |
step=512,
|
| 135 |
-
label="
|
| 136 |
-
info="
|
| 137 |
)
|
| 138 |
image_extract_btn = gr.Button(
|
| 139 |
-
"
|
| 140 |
)
|
| 141 |
|
| 142 |
with gr.Column(scale=2):
|
| 143 |
image_output_text = gr.Textbox(
|
| 144 |
-
label="
|
| 145 |
lines=20,
|
| 146 |
show_copy_button=True,
|
| 147 |
-
placeholder="
|
| 148 |
)
|
| 149 |
|
| 150 |
-
#
|
| 151 |
-
with gr.TabItem("
|
| 152 |
with gr.Row():
|
| 153 |
with gr.Column(scale=1):
|
| 154 |
-
|
| 155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
minimum=1024,
|
| 157 |
maximum=8192,
|
| 158 |
value=4096,
|
| 159 |
step=512,
|
| 160 |
-
label="
|
| 161 |
-
info="
|
| 162 |
)
|
| 163 |
-
|
| 164 |
-
"
|
| 165 |
)
|
| 166 |
|
| 167 |
with gr.Column(scale=2):
|
| 168 |
-
|
| 169 |
-
label="
|
| 170 |
lines=20,
|
| 171 |
show_copy_button=True,
|
| 172 |
-
placeholder="
|
| 173 |
)
|
| 174 |
|
| 175 |
# Event handlers for Image tab
|
|
@@ -187,43 +182,99 @@ with gr.Blocks(title="Nanonets OCR Demo") as demo:
|
|
| 187 |
show_progress=True,
|
| 188 |
)
|
| 189 |
|
| 190 |
-
# Event handlers for
|
| 191 |
-
|
| 192 |
-
fn=
|
| 193 |
-
inputs=[
|
| 194 |
-
outputs=
|
| 195 |
show_progress=True,
|
| 196 |
)
|
| 197 |
|
| 198 |
# Add model information section
|
| 199 |
-
with gr.Accordion("
|
| 200 |
gr.Markdown("""
|
| 201 |
## Nanonets-OCR-s
|
| 202 |
|
| 203 |
-
Nanonets-OCR-s
|
| 204 |
-
|
| 205 |
-
|
|
|
|
| 206 |
|
| 207 |
-
|
|
|
|
| 208 |
|
| 209 |
-
-
|
| 210 |
-
|
| 211 |
|
| 212 |
-
-
|
| 213 |
-
|
| 214 |
-
style, and context.
|
| 215 |
|
| 216 |
-
-
|
| 217 |
-
This is crucial for processing legal and business documents.
|
| 218 |
|
| 219 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
-
|
| 222 |
-
for consistent and reliable processing.
|
| 223 |
|
| 224 |
-
-
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
""")
|
| 227 |
|
| 228 |
if __name__ == "__main__":
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import base64
|
| 3 |
+
import io
|
| 4 |
import torch
|
|
|
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer
|
| 7 |
|
|
|
|
| 12 |
model = AutoModelForImageTextToText.from_pretrained(
|
| 13 |
model_path,
|
| 14 |
torch_dtype="auto",
|
| 15 |
+
device_map="cpu", # 使用CPU
|
|
|
|
| 16 |
)
|
| 17 |
model.eval()
|
| 18 |
|
|
|
|
| 21 |
print("Model loaded successfully!")
|
| 22 |
|
| 23 |
|
|
|
|
| 24 |
def ocr_image_gradio(image, max_tokens=4096):
|
| 25 |
"""Process image through Nanonets OCR model for Gradio interface"""
|
| 26 |
if image is None:
|
|
|
|
| 67 |
return output_text[0]
|
| 68 |
|
| 69 |
|
| 70 |
+
def ocr_base64_image(base64_string, max_tokens=4096):
    """Run OCR on a base64-encoded image.

    Accepts either a raw base64 payload or a full data URL
    (``data:image/...;base64,...``), decodes it into a PIL image, and
    delegates the actual text extraction to ``ocr_image_gradio``.  Any
    decoding/processing failure is returned as a human-readable error
    string rather than raised.
    """
    # Guard clause: reject missing or blank input before doing any work.
    if not base64_string or not base64_string.strip():
        return "Please provide a valid base64 image string."

    try:
        # Drop a leading data-URL header ("data:...;base64,") when present;
        # partition keeps everything after the first occurrence of the marker.
        marker = "base64,"
        if marker in base64_string:
            _, _, base64_string = base64_string.partition(marker)

        # Decode the payload and load it as a PIL image from memory.
        raw_bytes = base64.b64decode(base64_string)
        pil_image = Image.open(io.BytesIO(raw_bytes))

        # Reuse the standard image OCR path for the decoded image.
        return ocr_image_gradio(pil_image, max_tokens)
    except Exception as e:
        return f"Error processing base64 image: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
# Create Gradio interface
|
|
|
|
| 116 |
with gr.Row():
|
| 117 |
with gr.Column(scale=1):
|
| 118 |
image_input = gr.Image(
|
| 119 |
+
label="上传文档图片", type="pil", height=400
|
| 120 |
)
|
| 121 |
image_max_tokens = gr.Slider(
|
| 122 |
minimum=1024,
|
| 123 |
maximum=8192,
|
| 124 |
value=4096,
|
| 125 |
step=512,
|
| 126 |
+
label="最大Token数",
|
| 127 |
+
info="生成的最大token数量",
|
| 128 |
)
|
| 129 |
image_extract_btn = gr.Button(
|
| 130 |
+
"提取文本", variant="primary", size="lg"
|
| 131 |
)
|
| 132 |
|
| 133 |
with gr.Column(scale=2):
|
| 134 |
image_output_text = gr.Textbox(
|
| 135 |
+
label="提取的文本",
|
| 136 |
lines=20,
|
| 137 |
show_copy_button=True,
|
| 138 |
+
placeholder="提取的文本将显示在这里...",
|
| 139 |
)
|
| 140 |
|
| 141 |
+
# Base64 Image tab
|
| 142 |
+
with gr.TabItem("Base64图片OCR"):
|
| 143 |
with gr.Row():
|
| 144 |
with gr.Column(scale=1):
|
| 145 |
+
base64_input = gr.Textbox(
|
| 146 |
+
label="输入Base64编码的图片",
|
| 147 |
+
lines=10,
|
| 148 |
+
placeholder="粘贴Base64编码的图片数据...",
|
| 149 |
+
)
|
| 150 |
+
base64_max_tokens = gr.Slider(
|
| 151 |
minimum=1024,
|
| 152 |
maximum=8192,
|
| 153 |
value=4096,
|
| 154 |
step=512,
|
| 155 |
+
label="最大Token数",
|
| 156 |
+
info="生成的最大token数量",
|
| 157 |
)
|
| 158 |
+
base64_extract_btn = gr.Button(
|
| 159 |
+
"提取文本", variant="primary", size="lg"
|
| 160 |
)
|
| 161 |
|
| 162 |
with gr.Column(scale=2):
|
| 163 |
+
base64_output_text = gr.Textbox(
|
| 164 |
+
label="提取的文本",
|
| 165 |
lines=20,
|
| 166 |
show_copy_button=True,
|
| 167 |
+
placeholder="提取的文本将显示在这里...",
|
| 168 |
)
|
| 169 |
|
| 170 |
# Event handlers for Image tab
|
|
|
|
| 182 |
show_progress=True,
|
| 183 |
)
|
| 184 |
|
| 185 |
+
# Event handlers for Base64 tab
|
| 186 |
+
base64_extract_btn.click(
|
| 187 |
+
fn=ocr_base64_image,
|
| 188 |
+
inputs=[base64_input, base64_max_tokens],
|
| 189 |
+
outputs=base64_output_text,
|
| 190 |
show_progress=True,
|
| 191 |
)
|
| 192 |
|
| 193 |
# Add model information section
|
| 194 |
+
with gr.Accordion("关于 Nanonets-OCR-s", open=False):
|
| 195 |
gr.Markdown("""
|
| 196 |
## Nanonets-OCR-s
|
| 197 |
|
| 198 |
+
Nanonets-OCR-s 是一个强大的最先进的图像到markdown的OCR模型,远超传统的文本提取功能。
|
| 199 |
+
它将文档转换为带有智能内容识别和语义标记的结构化markdown,非常适合大型语言模型(LLM)的下游处理。
|
| 200 |
+
|
| 201 |
+
### 主要特点
|
| 202 |
|
| 203 |
+
- **LaTeX公式识别**:自动将数学公式转换为格式正确的LaTeX语法。
|
| 204 |
+
它区分内联($...$)和显示($$...$$)公式。
|
| 205 |
|
| 206 |
+
- **智能图像描述**:使用结构化的`<img>`标签描述文档中的图像,使它们易于LLM处理。
|
| 207 |
+
它可以描述各种图像类型,包括徽标、图表、图形等,详细说明它们的内容、风格和上下文。
|
| 208 |
|
| 209 |
+
- **签名检测与隔离**:识别并隔离签名与其他文本,将其输出在`<signature>`标签内。
|
| 210 |
+
这对处理法律和商业文件至关重要。
|
|
|
|
| 211 |
|
| 212 |
+
- **水印提取**:检测并提取文档中的水印文本,将其放在`<watermark>`标签内。
|
|
|
|
| 213 |
|
| 214 |
+
- **智能复选框处理**:将表单复选框和单选按钮转换为标准化的Unicode符号(☐, ☑, ☒),
|
| 215 |
+
以实现一致可靠的处理。
|
| 216 |
+
|
| 217 |
+
- **复杂表格提取**:准确地从文档中提取复杂表格,并将它们转换为markdown和HTML表格格式。
|
| 218 |
+
""")
|
| 219 |
+
|
| 220 |
+
# API Usage Information
|
| 221 |
+
with gr.Accordion("API使用说明", open=True):
|
| 222 |
+
gr.Markdown("""
|
| 223 |
+
## API使用方法
|
| 224 |
+
|
| 225 |
+
### Base64图片识别API
|
| 226 |
+
|
| 227 |
+
您可以通过HTTP POST请求使用Base64图片识别API:
|
| 228 |
+
|
| 229 |
+
```
|
| 230 |
+
curl -X POST "http://localhost:7860/api/predict" \\
|
| 231 |
+
-H "Content-Type: application/json" \\
|
| 232 |
+
-d '{
|
| 233 |
+
"fn_index": 1,
|
| 234 |
+
"data": [
|
| 235 |
+
"YOUR_BASE64_STRING_HERE",
|
| 236 |
+
4096
|
| 237 |
+
]
|
| 238 |
+
}'
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
- `fn_index: 1` 对应Base64图片OCR功能
|
| 242 |
+
- 第一个参数是Base64编码的图片字符串
|
| 243 |
+
- 第二个参数是最大token数量
|
| 244 |
+
|
| 245 |
+
### 普通图片上传API
|
| 246 |
+
|
| 247 |
+
```
|
| 248 |
+
curl -X POST "http://localhost:7860/api/predict" \\
|
| 249 |
+
-H "Content-Type: application/json" \\
|
| 250 |
+
-d '{
|
| 251 |
+
"fn_index": 0,
|
| 252 |
+
"data": [
|
| 253 |
+
"IMAGE_DATA_HERE",
|
| 254 |
+
4096
|
| 255 |
+
]
|
| 256 |
+
}'
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
- `fn_index: 0` 对应普通图片OCR功能
|
| 260 |
+
""")
|
| 261 |
+
|
| 262 |
+
# CPU Usage Warning
|
| 263 |
+
with gr.Accordion("CPU环境说明", open=True):
|
| 264 |
+
gr.Markdown("""
|
| 265 |
+
## CPU环境性能说明
|
| 266 |
|
| 267 |
+
此应用程序当前运行在CPU环境下(2核16G),请注意:
|
|
|
|
| 268 |
|
| 269 |
+
- 处理大型图像可能需要更长时间
|
| 270 |
+
- 建议使用较小的图像以获得更快的响应速度
|
| 271 |
+
- 如果处理时间过长,可以考虑降低最大Token数
|
| 272 |
+
- 模型已针对CPU环境进行了优化配置
|
| 273 |
""")
|
| 274 |
|
| 275 |
if __name__ == "__main__":
    # NOTE: the original re-imported torch here; that import is redundant
    # because torch is already imported at the top of the file, so it has
    # been removed.
    # Log the CPU configuration so deployment logs show what we run on.
    print(f"使用设备: CPU - 可用线程数: {torch.get_num_threads()}")
    # Cap torch at 2 threads to match the 2 CPU cores available in this
    # Space and avoid oversubscription.
    torch.set_num_threads(2)
    # Enable the request queue so long-running CPU inference jobs are
    # serialized instead of overlapping, then expose the app on all
    # interfaces with a public share link.
    demo.queue().launch(share=True, server_name="0.0.0.0")
|