| | import atexit |
| | import functools |
| | import base64 |
| | import io |
| | import re |
| | import os |
| | import tempfile |
| | from queue import Queue |
| | from threading import Event, Thread |
| | import numpy as np |
| | from paddleocr import PaddleOCR, draw_ocr |
| | from PIL import Image |
| | import gradio as gr |
| | import fasttext |
| |
|
| |
|
| | |
| | |
# Load the fasttext language-identification model, downloading it next to
# this file on first run. On any failure, fall back to the feature-based
# detector by leaving lang_model as None.
try:
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    _lid_model_path = os.path.join(_script_dir, "lid.176.bin")
    if not os.path.exists(_lid_model_path):
        import urllib.request
        print("下载fasttext语言检测模型...")
        urllib.request.urlretrieve(
            "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
            _lid_model_path,
        )
    lang_model = fasttext.load_model(_lid_model_path)
    print("fasttext语言检测模型加载成功")
except Exception as e:
    print(f"警告: 无法加载fasttext模型: {e}")
    lang_model = None
| |
|
| |
|
# Worker-pool size per supported PaddleOCR language code.
LANG_CONFIG = {
    "ch": {"num_workers": 2},
    "en": {"num_workers": 2},
    "fr": {"num_workers": 1},
    "german": {"num_workers": 1},
    "korean": {"num_workers": 1},
    "japan": {"num_workers": 1},
}

# PaddleOCR language code -> display name shown to the user.
LANG_MAP = {
    "ch": "中文",
    "en": "英文",
    "fr": "法语",
    "german": "德语",
    "korean": "韩语",
    "japan": "日语",
}

# fasttext predicted label (ISO 639-1 code) -> PaddleOCR language code.
FASTTEXT_TO_PADDLE = {
    "zh": "ch",
    "en": "en",
    "fr": "fr",
    "de": "german",
    "ko": "korean",
    "ja": "japan",
}

# Characteristic character sets used by the fallback detector.
# Korean deliberately has no entry: detect_language_by_features() matches
# Hangul by the Unicode range '\uac00'-'\ud7a3' instead of a char set.
LANG_FEATURES = {
    "ch": set("的一是不了人我在有他这为之大来以个中上们"),
    "en": set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"),
    "fr": set("àâäæçéèêëîïôœùûüÿÀÂÄÆÇÉÈÊËÎÏÔŒÙÛÜŸ"),
    "german": set("äöüßÄÖÜ"),
    "japan": set("あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほまみむめもやゆよらりるれろわをんアイウエオカキクケコサシスセソタチツテト")
}

# NOTE(review): not referenced anywhere in this file — presumably intended
# for the Gradio queue/concurrency settings; confirm before removing.
CONCURRENCY_LIMIT = 8
| |
|
| |
|
class PaddleOCRModelManager(object):
    """Serializes OCR calls through a pool of worker threads.

    Each worker thread owns one model instance built by ``model_factory``,
    so callers never share a single (non-thread-safe) OCR object. Requests
    are submitted via :meth:`infer` and dispatched over an internal queue.
    """

    def __init__(self,
                 num_workers,
                 model_factory):
        """Start ``num_workers`` threads, each creating its own model.

        Raises whatever exception ``model_factory`` raised in the worker.
        (Previously a factory failure left the event unset and this
        constructor blocked forever on ``wait()``.)
        """
        super().__init__()
        self._model_factory = model_factory
        self._queue = Queue()
        self._workers = []
        self._model_initialized_event = Event()
        self._init_error = None  # set by a worker whose model failed to build
        for _ in range(num_workers):
            worker = Thread(target=self._worker, daemon=False)
            worker.start()
            # Wait until this worker either built its model or failed.
            self._model_initialized_event.wait()
            self._model_initialized_event.clear()
            if self._init_error is not None:
                # Fail fast: stop the workers already started, then re-raise.
                self.close()
                raise self._init_error
            self._workers.append(worker)

    def infer(self, *args, **kwargs):
        """Run ``model.ocr(*args, **kwargs)`` on a worker and return its result.

        Blocks until a worker picks up the job; re-raises any exception the
        worker's OCR call raised.
        """
        result_queue = Queue(maxsize=1)
        self._queue.put((args, kwargs, result_queue))
        success, payload = result_queue.get()
        if success:
            return payload
        else:
            raise payload

    def close(self):
        """Signal every worker to exit and join the threads."""
        for _ in self._workers:
            self._queue.put(None)  # None is the shutdown sentinel
        for worker in self._workers:
            worker.join()

    def _worker(self):
        """Worker loop: build one model, then serve jobs until sentinel."""
        try:
            model = self._model_factory()
        except Exception as e:
            # Report the failure to __init__ (blocked on the event) instead
            # of silently dying and deadlocking the constructor.
            self._init_error = e
            self._model_initialized_event.set()
            return
        self._model_initialized_event.set()
        while True:
            item = self._queue.get()
            if item is None:
                break
            args, kwargs, result_queue = item
            try:
                result = model.ocr(*args, **kwargs)
                result_queue.put((True, result))
            except Exception as e:
                result_queue.put((False, e))
            finally:
                self._queue.task_done()
| |
|
| |
|
def create_model(lang):
    """Build a PaddleOCR engine for *lang* (angle classification on, CPU only).

    The original branched on ``lang == "ch"`` but both branches were
    byte-identical, so the dead duplication is collapsed into one call.
    """
    return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
| |
|
| |
|
| | |
# Eagerly build one worker pool per supported language so the first
# request does not pay model-load latency.
print("正在初始化多语言OCR模型...")
model_managers = {}
for language, settings in LANG_CONFIG.items():
    print(f"加载 {LANG_MAP.get(language, language)} 模型...")
    model_managers[language] = PaddleOCRModelManager(
        settings["num_workers"],
        functools.partial(create_model, lang=language),
    )
print("所有OCR模型加载完成")
| |
|
| |
|
def close_model_managers():
    """Shut down every per-language OCR worker pool."""
    for manager in list(model_managers.values()):
        manager.close()
| |
|
| |
|
| | |
# Workers are non-daemon threads; join them at interpreter exit so the
# process can terminate cleanly.
atexit.register(close_model_managers)
| |
|
| |
|
def detect_language_by_features(text):
    """Guess a PaddleOCR language code from characteristic characters.

    Scores each language in LANG_FEATURES by the fraction of *text*'s
    characters found in its feature set; Korean is scored separately via
    the Hangul-syllable Unicode range. Defaults to "en" when nothing hits.
    """
    if not text:
        return "en"

    total = len(text)
    scores = {}
    for language, charset in LANG_FEATURES.items():
        if not charset:
            continue
        hits = sum(ch in charset for ch in text)
        if hits:
            scores[language] = hits / total

    # Hangul syllables (U+AC00..U+D7A3) mark Korean.
    hangul_hits = sum('\uac00' <= ch <= '\ud7a3' for ch in text)
    if hangul_hits:
        scores["korean"] = hangul_hits / total

    if not scores:
        return "en"

    return max(scores, key=scores.get)
| |
|
| |
|
def detect_language_with_fasttext(text):
    """Detect *text*'s language with fasttext, falling back to feature sets.

    Returns a PaddleOCR language code; "en" for empty/blank input. Uses the
    feature-based detector when the fasttext model is unavailable, when the
    predicted label has no PaddleOCR mapping, or when prediction fails.
    """
    if not text or not text.strip():
        return "en"

    if lang_model is None:
        return detect_language_by_features(text)

    try:
        # Cap the sample: long inputs add nothing to the prediction.
        sample = text[:1000]

        # fasttext rejects embedded newlines, so flatten them first.
        labels = lang_model.predict(sample.replace('\n', ' '))[0]
        lang_code = labels[0].replace('__label__', '')

        mapped = FASTTEXT_TO_PADDLE.get(lang_code)
        if mapped is None:
            return detect_language_by_features(sample)
        return mapped
    except Exception as e:
        print(f"语言检测错误: {e}")
        return detect_language_by_features(sample)
| |
|
| |
|
def try_all_languages(image_path):
    """OCR *image_path* with every configured language; return (result, lang).

    "Best" is the run whose concatenated text is longest. Chinese short-
    circuits as soon as it yields more than 10 characters. Returns
    (None, "en") when no language produced any text.
    """
    best_result, best_lang, best_len = None, "en", 0

    for language in LANG_CONFIG:
        try:
            result = model_managers[language].infer(image_path, cls=True)[0]
            if result:
                joined = " ".join(line[1][0] for line in result)
                length = len(joined.strip())

                if length > best_len:
                    best_result, best_lang, best_len = result, language, length

                # Chinese with substantial text wins immediately.
                if language == "ch" and length > 10:
                    return result, language
        except Exception as e:
            print(f"OCR处理错误 ({language}): {e}")
            continue

    return best_result, best_lang
| |
|
| |
|
def auto_detect_language(image_path):
    """Pick the OCR language for *image_path* via a two-engine vote.

    Runs the Chinese and English engines, classifies each engine's text
    with fasttext and tallies votes; short-circuits to "ch" when the
    Chinese engine found real CJK characters, and brute-forces every
    language when no vote was cast.
    """
    votes = {}
    probe_texts = {}

    for probe_lang in ("ch", "en"):
        try:
            result = model_managers[probe_lang].infer(image_path, cls=True)[0]
            if result:
                joined = " ".join(line[1][0] for line in result)
                probe_texts[probe_lang] = joined

                if joined.strip():
                    guess = detect_language_with_fasttext(joined)
                    votes[guess] = votes.get(guess, 0) + 1
        except Exception as e:
            print(f"OCR处理错误 ({probe_lang}): {e}")
            continue

    # Trust the Chinese engine directly once enough CJK characters appear.
    if "ch" in probe_texts:
        cjk_count = sum('\u4e00' <= ch <= '\u9fff' for ch in probe_texts["ch"])
        if cjk_count > 5:
            return "ch"

    if not votes:
        print("无法可靠检测语言,尝试所有语言...")
        return try_all_languages(image_path)[1]

    return max(votes, key=votes.get)
| |
|
| |
|
def save_base64_to_temp_file(base64_string):
    """Decode a Base64 image string and write it to a temporary .png file.

    Accepts either a raw Base64 payload or a full data URL
    ("data:image/...;base64,...."). Returns the temp-file path; the caller
    is responsible for deleting it.

    Raises:
        ValueError: if the string cannot be decoded or the file written
            (chained to the underlying exception).
    """
    try:
        # Strip a data-URL header if present.
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,")[1]

        image_data = base64.b64decode(base64_string)

        # Context manager closes the handle even if write() fails
        # (the original manual write/close leaked the handle on error).
        # delete=False: the path must outlive this function for the OCR call.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            temp_file.write(image_data)

        return temp_file.name
    except Exception as e:
        # Chain the cause so the original decode/IO error is not lost.
        raise ValueError(f"处理Base64图像时出错: {str(e)}") from e
| |
|
| |
|
def inference(img, return_text_only=True):
    """Run OCR on *img* with automatic language detection.

    Args:
        img: a file path, a Base64 string (optionally with a "data:" URL
            prefix), or any value the OCR pipeline accepts directly.
        return_text_only: when True return ``(text, language_name)``;
            otherwise return ``(annotated_image, text, language_name)``.

    Fixes over the original: PaddleOCR returns ``None`` (not ``[]``) for an
    empty page, and ``try_all_languages`` can return ``None`` — both
    previously crashed the list comprehensions below with TypeError. The
    source image and the annotation font are now only loaded on the
    annotated-image path.
    """
    temp_file = None

    try:
        # Base64 payloads are written to a temp file; plain paths and
        # non-string inputs pass through unchanged.
        if isinstance(img, str):
            # NOTE(review): a plain path made only of Base64-alphabet chars
            # (no dot, e.g. "dir/imagefile") would match this regex and be
            # misread as Base64 — confirm callers always pass extensions.
            if img.startswith("data:") or re.match(r'^[A-Za-z0-9+/=]+$', img):
                temp_file = save_base64_to_temp_file(img)
                img_path = temp_file
            else:
                img_path = img
        else:
            img_path = img

        lang = auto_detect_language(img_path)
        print(f"检测到的语言: {LANG_MAP.get(lang, lang)}")

        ocr = model_managers[lang]
        result = ocr.infer(img_path, cls=True)[0]
        if result is None:  # PaddleOCR yields None when nothing is recognized
            result = []

        # If the first pass found almost nothing, brute-force all languages.
        all_text = " ".join([line[1][0] for line in result])
        if len(all_text.strip()) < 5:
            print("识别结果太少,尝试所有语言...")
            result, lang = try_all_languages(img_path)
            if result is None:
                result = []
            print(f"最佳语言: {LANG_MAP.get(lang, lang)}")

        boxes = [line[0] for line in result]
        txts = [line[1][0] for line in result]
        scores = [line[1][1] for line in result]

        if return_text_only:
            return "\n".join(txts), LANG_MAP.get(lang, lang)

        # Only the annotated-image path needs the source image and a font.
        pil_img = Image.open(img_path).convert("RGB")
        font_path = _find_annotation_font()
        try:
            im_show = draw_ocr(pil_img, boxes, txts, scores, font_path=font_path)
            return im_show, "\n".join(txts), LANG_MAP.get(lang, lang)
        except Exception as e:
            print(f"绘制OCR结果时出错: {e}")
            # Fall back to the unannotated image so callers still get output.
            return pil_img, "\n".join(txts), LANG_MAP.get(lang, lang)

    finally:
        # Always remove the temp file created for Base64 input.
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except OSError:
                pass


def _find_annotation_font():
    """Locate a font file for draw_ocr, checking known PaddleOCR locations.

    Returns "./simfang.ttf" (possibly nonexistent) when no candidate is
    found, matching the original behavior; draw_ocr's own error handling
    then applies.
    """
    font_path = "./simfang.ttf"
    if os.path.exists(font_path):
        return font_path

    possible_paths = [
        "./doc/fonts/simfang.ttf",
        "/usr/local/lib/python3.10/site-packages/paddleocr/doc/fonts/simfang.ttf",
        "/usr/local/lib/python3.10/site-packages/paddleocr/ppocr/utils/fonts/simfang.ttf",
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path
    return font_path
| |
|
| |
|
def inference_with_image(img):
    """Return the annotated image, recognized text, and detected language."""
    return inference(img, return_text_only=False)
| |
|
| |
|
def inference_text_only(img):
    """Return only the recognized text and detected language."""
    return inference(img, return_text_only=True)
| |
|
| |
|
def inference_base64(base64_string):
    """OCR a Base64-encoded image; returns (text, language).

    On empty input or any processing error, the first element is a
    user-facing error message and the second is an empty string.
    """
    if not base64_string or not base64_string.strip():
        return "请提供有效的Base64图像字符串", ""

    try:
        return inference(base64_string, return_text_only=True)
    except Exception as e:
        return f"处理Base64图像时出错: {str(e)}", ""
| |
|
| |
|
# Page title and markdown description rendered at the top of the app.
title = '🔍 PaddleOCR 智能文字识别'
description = '''
### 功能特点
- 支持中文、英文、法语、德语、韩语和日语的智能文字识别
- 自动检测图像中的语言,无需手动选择
- 支持Base64编码图像识别
- 同时提供文本结果和标注图像

### 使用方法
- 上传图像或提供Base64编码的图像数据
- 系统会自动检测语言并进行OCR识别
- 查看识别结果和标注图像
'''

# Sample images (paths relative to the working directory).
# NOTE(review): `examples` is never wired into the Blocks UI below —
# confirm whether a gr.Examples component was intended.
examples = [
    ['en_example.jpg'],
    ['cn_example.jpg'],
    ['jp_example.jpg'],
]
| |
|
| | |
# Custom CSS injected into the Gradio app: global font, fixed-height
# image panes, monospace result box, and language badge styling.
css = """
.gradio-container {
    font-family: 'Roboto', 'Microsoft YaHei', sans-serif;
}
.output_image, .input_image {
    height: 30rem !important;
    width: 100% !important;
    object-fit: contain;
    border-radius: 8px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.tabs {
    margin-top: 0.5rem;
}
.output-text {
    font-family: 'Courier New', monospace;
    line-height: 1.5;
    padding: 1rem;
    border-radius: 8px;
    background-color: #f8f9fa;
    border: 1px solid #e9ecef;
}
.detected-lang {
    font-weight: bold;
    color: #4285f4;
    margin-bottom: 0.5rem;
}
"""
| |
|
| | |
# ---- Gradio UI: two tabs (file upload / Base64 input) plus API notes. ----
with gr.Blocks(title=title, css=css) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(description)

    with gr.Tabs() as tabs:
        # Tab 1: upload an image file; shows annotated image + text + language.
        with gr.TabItem("图像上传识别"):
            with gr.Row():
                with gr.Column(scale=1):
                    # type="filepath": inference() receives a path string.
                    image_input = gr.Image(label="上传图像", type="filepath")
                    image_submit = gr.Button("开始识别", variant="primary")

                with gr.Column(scale=2):
                    with gr.Row():
                        image_output = gr.Image(label="标注结果", type="pil")
                    with gr.Row():
                        detected_lang = gr.Textbox(label="检测到的语言", lines=1)
                    with gr.Row():
                        text_output = gr.Textbox(label="识别文本", lines=10, elem_classes=["output-text"])

        # Tab 2: paste a Base64-encoded image; text-only result.
        with gr.TabItem("Base64图像识别"):
            with gr.Row():
                with gr.Column(scale=1):
                    base64_input = gr.Textbox(
                        label="输入Base64编码的图像数据",
                        lines=8,
                        placeholder="在此粘贴Base64编码的图像数据..."
                    )
                    base64_submit = gr.Button("开始识别", variant="primary")

                with gr.Column(scale=2):
                    base64_lang = gr.Textbox(label="检测到的语言", lines=1)
                    base64_output = gr.Textbox(
                        label="识别文本",
                        lines=15,
                        elem_classes=["output-text"]
                    )

    # Collapsible curl usage examples for the two handlers (fn_index 0 and 1).
    with gr.Accordion("API使用说明", open=False):
        gr.Markdown("""
        ## API使用方法

        ### 1. 图像上传API

        ```bash
        curl -X POST "http://localhost:7860/api/predict" \\
            -F "fn_index=0" \\
            -F "data=@/path/to/your/image.jpg"
        ```

        ### 2. Base64图像API

        ```bash
        curl -X POST "http://localhost:7860/api/predict" \\
            -H "Content-Type: application/json" \\
            -d '{
                "fn_index": 1,
                "data": ["YOUR_BASE64_STRING_HERE"]
            }'
        ```
        """)

    # Wire buttons to handlers. Output order must match the handlers'
    # return order: (image, text, lang) and (text, lang) respectively.
    image_submit.click(
        fn=inference_with_image,
        inputs=[image_input],
        outputs=[image_output, text_output, detected_lang]
    )

    base64_submit.click(
        fn=inference_base64,
        inputs=[base64_input],
        outputs=[base64_output, base64_lang]
    )


# Launch the app locally (no public share link, no debug reload).
demo.launch(debug=False, share=False)