# logistics_ocr / app.py
# Author: mlbench123
# Last commit: e277539 ("Update app.py", verified)
import gradio as gr
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional
import traceback
from PIL import Image
import PyPDF2
import pytesseract
from pdf2image import convert_from_path
from huggingface_hub import InferenceClient
# ==============================================================
# Extraction prompt
# ==============================================================
# Instruction text placed at the start of every model prompt by
# extract_with_hf. The extracted document text and the JSON-encoded list of
# attachment names are appended after it.
# The "schema" is informal pseudo-JSON (`string | null`, `[string]`) for the
# model to imitate; it is not a machine-checked JSON Schema. The parsed reply
# is handed to the UI as-is, so renaming a key here will likely change the
# keys in the output JSON.
EXTRACTION_PROMPT = """You are an expert shipping-document data extractor.
You will be given OCR/text extracted from shipping documents.
Extract and return ONLY valid JSON matching this schema:
{
"poNumber": string | null,
"shipFrom": string | null,
"carrierType": string | null,
"originCarrier": string | null,
"railCarNumber": string | null,
"totalQuantity": number | null,
"totalUnits": string | null,
"attachments": [string],
"accountName": string | null,
"inventories": {
"items": [
{
"quantityShipped": number | null,
"inventoryUnits": string | null,
"pcs": number | null,
"productName": string | null,
"productCode": string | null,
"product": {
"category": number | null,
"defaultUnits": string | null,
"unit": string | null,
"pcs": number | null,
"mbf": number | null,
"sf": number | null,
"pcsHeight": number | null,
"pcsWidth": number | null,
"pcsLength": number | null
},
"customFields": [string]
}
]
}
}
Return ONLY JSON. No explanation.
"""
# ==============================================================
# JSON Helpers
# ==============================================================
def extract_json(text: str) -> Dict:
    """Parse the first JSON object embedded in a model reply.

    Model replies may wrap the JSON in a Markdown code fence or surround it
    with commentary. Decoding starts at the first ``{`` and stops at the end
    of that object, so the following are all ignored:

    * a leading fence or leading prose;
    * trailing text, even if it contains braces itself.

    String values inside the JSON are never modified.

    Args:
        text: Raw model output. ``None`` is treated as an empty string.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If ``text`` contains no ``{``, or the object
            starting at the first ``{`` is not valid JSON (e.g. the reply
            was truncated).
    """
    text = text or ""
    start = text.find("{")
    if start == -1:
        raise json.JSONDecodeError("No JSON found", text, 0)
    # raw_decode parses one JSON value and reports where it ended.
    # Anything after the object is therefore tolerated instead of causing
    # an "Extra data" error.
    parsed, _end = json.JSONDecoder().raw_decode(text[start:])
    return parsed
# ==============================================================
# OCR + TEXT EXTRACTION
# ==============================================================
def extract_text_from_pdf(pdf_path: str) -> str:
    """Read the embedded text layer of a PDF with PyPDF2.

    Each page that yields non-empty text contributes that text followed by a
    newline. Pages without extractable text are skipped.

    Any exception (missing file, unreadable PDF, extraction failure) is not
    raised. Instead the function returns the string ``"PDF text error: <exc>"``.
    """
    try:
        with open(pdf_path, "rb") as handle:
            page_texts = [page.extract_text() for page in PyPDF2.PdfReader(handle).pages]
            return "".join(chunk + "\n" for chunk in page_texts if chunk)
    except Exception as e:
        return f"PDF text error: {e}"
def ocr_image(img: Image.Image) -> str:
    """Run Tesseract OCR on a PIL image and return the recognised text.

    Tesseract is always given an RGB image. An image counts as carrying
    transparency if its mode is RGBA, LA or PA, or if it has a
    ``"transparency"`` entry in ``img.info``. Such images are first
    composited onto a white background. A plain ``convert("RGB")`` would just
    drop the alpha channel, and fully transparent pixels often store black,
    which would hide dark text.
    Other non-RGB modes are converted directly.
    """
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        flattened = Image.new("RGB", rgba.size, (255, 255, 255))
        # Passing the RGBA image as the mask makes its alpha band control how
        # much of each pixel is pasted over the white background.
        flattened.paste(rgba, mask=rgba)
        img = flattened
    elif img.mode != "RGB":
        img = img.convert("RGB")
    return pytesseract.image_to_string(img)
def extract_pdf_with_ocr(pdf_path: str, min_text_chars: int = 50, dpi: int = 250) -> str:
    """Return a PDF's text, falling back to OCR for image-only PDFs.

    The embedded text layer from ``extract_text_from_pdf`` is used when its
    length, after trimming surrounding whitespace, exceeds ``min_text_chars``.
    Otherwise, every page is rasterised at ``dpi`` with ``convert_from_path``
    and passed through ``ocr_image``. Each page's OCR text is followed by a
    newline.

    Args:
        pdf_path: Path to the PDF file.
        min_text_chars: The embedded text must be longer than this (after
            stripping) to be trusted without OCR.
        dpi: Rasterisation resolution used for OCR.

    Returns:
        The embedded text, or the OCR text. If OCR produces only whitespace,
        whatever embedded text was found is returned instead (possibly "").
    """
    text = extract_text_from_pdf(pdf_path) or ""
    # extract_text_from_pdf reports failures as a "PDF text error: ..." string
    # instead of raising. A long error message must not be mistaken for page
    # text: that would skip OCR and send the error to the model as content.
    if text.startswith("PDF text error:"):
        text = ""
    # Strip first so a text layer made only of blank lines/spaces does not
    # suppress OCR.
    if len(text.strip()) > min_text_chars:
        return text
    pages = convert_from_path(pdf_path, dpi=dpi)
    ocr_text = "".join(ocr_image(page) + "\n" for page in pages)
    return ocr_text if ocr_text.strip() else text
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")
_PLAIN_TEXT_EXTENSIONS = (".txt", ".csv")
_WORD_EXTENSIONS = (".doc", ".docx")


def _extract_file_text(path: str, ext: str) -> str:
    """Return the text of one uploaded file, dispatching on its lower-cased extension."""
    if ext == ".pdf":
        return extract_pdf_with_ocr(path)
    if ext in _IMAGE_EXTENSIONS:
        # The context manager closes the image's file handle once OCR is done.
        with Image.open(path) as img:
            return ocr_image(img)
    if ext in _PLAIN_TEXT_EXTENSIONS:
        # Undecodable bytes are dropped rather than failing the upload.
        with open(path, encoding="utf-8", errors="ignore") as fh:
            return fh.read()
    if ext in _WORD_EXTENSIONS:
        # Imported lazily so python-docx is only needed when a Word file arrives.
        # NOTE(review): python-docx targets .docx; a legacy binary .doc is
        # expected to raise here and is then reported via the "errors" list.
        import docx

        document = docx.Document(path)
        return "\n".join(p.text for p in document.paragraphs)
    return ""


def process_files(files: List[str]) -> Dict[str, Any]:
    """Extract text from every uploaded file and combine it into one document.

    Supported inputs:

    * ``.pdf``: embedded text, with OCR fallback.
    * ``.jpg``, ``.jpeg``, ``.png``, ``.webp``: OCR.
    * ``.txt``, ``.csv``: read as UTF-8.
    * ``.doc``, ``.docx``: paragraph text via python-docx.

    Any other extension contributes an empty section. Extension matching is
    case-insensitive.

    Args:
        files: Paths of the uploaded files.

    Returns:
        A dict with these keys:

        * ``text_content``: for each file, ``"\\n\\n=== <name> ===\\n<text>"``
          concatenated in input order.
        * ``attachments``: the files' base names, in input order.
        * ``errors``: ``"<name>: <message>"`` for each file whose extraction
          raised. Such a file still gets an (empty) section, so one bad
          upload does not discard the rest of the batch.
    """
    result: Dict[str, Any] = {"text_content": "", "attachments": [], "errors": []}
    sections: List[str] = []
    for f in files:
        name = Path(f).name
        ext = Path(f).suffix.lower()
        result["attachments"].append(name)
        try:
            text = _extract_file_text(f, ext)
        except Exception as exc:
            # Per-file boundary: record the failure explicitly and keep going.
            result["errors"].append(f"{name}: {exc}")
            text = ""
        sections.append(f"\n\n=== {name} ===\n{text}")
    result["text_content"] = "".join(sections)
    return result
# ==============================================================
# HF MODEL CALL (chat completion, legacy conversational fallback)
# ==============================================================
def extract_with_hf(processed_data: Dict[str, Any]) -> Dict[str, Any]:
    """Send extracted document text to a Hugging Face model and parse its JSON reply.

    The model id comes from ``HF_MODEL`` (default
    ``mistralai/Mistral-7B-Instruct-v0.3``). The API token comes from
    ``HF_TOKEN``.

    The chat-completion API is tried first. The legacy ``conversational``
    task is used only if chat completion fails and the installed client
    still provides that method.

    Args:
        processed_data: Output of ``process_files``. Must contain
            ``"text_content"`` (str) and ``"attachments"`` (list of file names).

    Returns:
        On success: ``{"success": True, "data": <parsed dict>, "raw": <reply>}``.

        On failure: ``{"success": False, "error": <message>, ...}``. The extra
        key is ``"traceback"`` when the model call itself failed, or ``"raw"``
        when the reply was empty or not valid JSON.
    """
    hf_token = os.getenv("HF_TOKEN")
    model = os.getenv("HF_MODEL", "mistralai/Mistral-7B-Instruct-v0.3")
    client = InferenceClient(model=model, token=hf_token)

    prompt = (
        EXTRACTION_PROMPT
        + "\n\nDOCUMENT TEXT:\n"
        + processed_data["text_content"]
        + "\n\nATTACHMENTS:\n"
        + json.dumps(processed_data["attachments"])
    )

    raw = ""
    try:
        resp = client.chat_completion(
            messages=[
                {"role": "system", "content": "Return strict JSON only."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.1,
            max_tokens=3000,
        )
        # message.content may be None; normalise to "" for the checks below.
        raw = resp.choices[0].message.content or ""
    except Exception as chat_err:
        # Older huggingface_hub releases have no chat_completion but do expose
        # the (since deprecated) conversational task. Its first argument is
        # the user text itself.
        conversational = getattr(client, "conversational", None)
        if conversational is None:
            return {
                "success": False,
                "error": f"Model call failed:\n{chat_err}",
                "traceback": traceback.format_exc(),
            }
        try:
            conv = conversational(prompt)
            raw = conv["generated_text"]
        except Exception as conv_err:
            return {
                "success": False,
                "error": f"Model call failed:\n{chat_err}\n\n{conv_err}",
                "traceback": traceback.format_exc(),
            }

    if not raw or not raw.strip():
        return {
            "success": False,
            "error": "Model returned an empty response",
            "raw": raw or "",
        }

    try:
        parsed = extract_json(raw)
    except Exception as je:
        return {
            "success": False,
            "error": f"JSON parse error: {je}",
            "raw": raw,
        }
    return {
        "success": True,
        "data": parsed,
        "raw": raw,
    }
# ==============================================================
# MAIN PROCESS
# ==============================================================
def process_documents(files):
    """Gradio click handler: extract text from the uploads, then JSON via the model.

    Args:
        files: Value of the multi-file upload component. Each item is either
            a path string or an object whose ``.name`` holds the path. An
            empty value or ``None`` means nothing was uploaded.

    Returns:
        A ``(status, json_text, preview)`` tuple for the three output
        components.

        * ``status``: the progress log followed by the final outcome.
        * On success, ``json_text`` and ``preview`` are the same
          pretty-printed JSON.
        * On failure, ``json_text`` is ``"{}"``, and ``preview`` shows the raw
          model reply, or a traceback when no reply is available.
    """
    if not files:
        return "❌ Upload file", "{}", ""
    paths = [f.name if hasattr(f, "name") else f for f in files]

    # Progress log shown in the Status box ahead of the final outcome line.
    # NOTE(review): the emoji literals below look mis-encoded; kept byte-for-byte.
    status = "πŸ“„ Extracting text...\n"
    try:
        processed = process_files(paths)
        # .get keeps this working with process_files versions that do not
        # report per-file errors.
        for err in processed.get("errors") or []:
            status += f"Warning: {err}\n"
        status += "πŸ€– Calling HF model...\n"
        result = extract_with_hf(processed)
    except Exception as exc:
        # UI boundary: report the failure in the Status box instead of letting
        # it escape as an opaque Gradio error.
        return f"{status}❌ Processing failed: {exc}", "{}", traceback.format_exc()

    if result["success"]:
        json_out = json.dumps(result["data"], indent=2)
        return status + "βœ… Success", json_out, json_out
    preview = result.get("raw") or result.get("traceback", "")
    return status + f"❌ Extraction failed:\n{result['error']}", "{}", preview
# ==============================================================
# UI
# ==============================================================
# Gradio front end: multi-file upload, an Extract button, and three outputs
# (status text, extracted JSON, plain-text preview). The app is launched at
# module load, listening on all interfaces at port 7860.
# NOTE(review): several emoji literals here look mis-encoded (e.g. "πŸ“„",
# apparently UTF-8 emoji decoded with a legacy codepage). They are kept
# byte-for-byte; confirm and repair them deliberately.
with gr.Blocks() as demo:
    gr.Markdown("# πŸ“„ Logistic OCR – Open Source Version")
    uploads = gr.File(file_count="multiple")
    extract_button = gr.Button("πŸš€ Extract")
    status_box = gr.Textbox(label="Status")
    json_view = gr.Code(language="json")
    preview_box = gr.Textbox(label="Preview")

    # process_documents returns (status, json_text, preview) in this order.
    extract_button.click(
        fn=process_documents,
        inputs=uploads,
        outputs=[status_box, json_view, preview_box],
    )

launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
demo.launch(**launch_options)