import torch
import re
import os
import textract
from fpdf import FPDF
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# --- Configuration ---
# All paths are now local
INPUT_DOC_PATH = "Doreen.doc"
OUTPUT_PDF_PATH = "Doreen_DeFio_Report_Local_Test.pdf"
# --- Model Paths (loading from local Hugging Face cache) ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# FIX: This now points to the local folder containing your fine-tuned model.
LORA_ADAPTER_PATH = "gemma-grammar-lora"
# --- Global variables for models ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"
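# NOTE: everything below assumes CPU-only inference. If a CUDA-capable GPU is
# available, setting device = "cuda" should work unchanged, since every model
# and input tensor is moved with .to(device).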
# --- 1. Model Loading Logic (from main.py) ---
def load_all_models():
    """Loads all AI models into memory."""
    global grammar_model, grammar_tokenizer, gender_model, gender_tokenizer
    print("--- Starting Model Loading ---")
    try:
        print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
        gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
        gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
        print("✅ Gender verifier model loaded successfully!")
        print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH, torch_dtype=torch.float32
        ).to(device)
        grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
        print(f"Applying LoRA adapter from local folder: {LORA_ADAPTER_PATH}")
        grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
        print("✅ Grammar correction model loaded successfully!")
        # Decoder-only models often ship without a pad token; fall back to EOS
        # so generate() can build attention masks.
        if grammar_tokenizer.pad_token is None:
            grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
        if gender_tokenizer.pad_token is None:
            gender_tokenizer.pad_token = gender_tokenizer.eos_token
    except Exception as e:
        print(f"❌ Critical error during model loading: {e}")
        return False
    print("--- Model Loading Complete ---")
    return True
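# NOTE: keeping both models resident is memory-hungry. gemma-2b-it alone is
# roughly 10 GB in float32 (about 2.5B parameters at 4 bytes each), so a
# machine with 16 GB+ of RAM is a realistic minimum for this CPU pipeline.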
# --- 2. Correction Functions (adapted from main.py) ---
def run_grammar_correction(text: str) -> str:
    """Corrects grammar using the loaded LoRA model."""
    if not grammar_model:
        return text
    input_text = f"Prompt: {text}\nResponse:"
    inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
    output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
    output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Cleaning logic: keep only the text after the "Response:" marker
    if "Response:" in output_text:
        parts = output_text.split("Response:")
        if len(parts) > 1:
            return parts[1].strip()
    return output_text.strip()
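# NOTE: the "Prompt: ...\nResponse:" template above must match the format the
# LoRA adapter was fine-tuned on; with a mismatched template the model tends
# to echo the prompt or produce unrelated text. max_new_tokens=64 also caps
# the output, so very long fields may come back truncated.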
def run_gender_correction(text: str) -> str:
    """Corrects gender using the loaded gender model and regex."""
    if not gender_model:
        return text
    input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{text}\nResponse:"
    inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
    # temperature is omitted: it is ignored under greedy decoding (do_sample=False)
    output_ids = gender_model.generate(
        **inputs, max_new_tokens=64,
        do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
    )
    output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Cleaning logic
    if "Response:" in output_text:
        parts = output_text.split("Response:")
        if len(parts) > 1:
            output_text = parts[1].strip()
    cleaned_text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', output_text, flags=re.IGNORECASE).strip().strip('"')
    # Regex safety net
    corrections = {
        r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
    }
    for pattern, replacement in corrections.items():
        cleaned_text = re.sub(pattern, replacement, cleaned_text, flags=re.IGNORECASE)
    return cleaned_text
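# The two correctors are meant to be chained, grammar first and then the
# gender pass (see parse_and_correct_text below). A standalone sketch with a
# made-up input; actual output depends on the fine-tuned weights:
#
#   raw = "The claimant said she hurt she knee. Her wife drove her home."
#   fixed = run_gender_correction(run_grammar_correction(raw))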
# --- 3. Document Processing Logic (from document_pipeline.py) ---
def extract_text_from_doc(filepath):
    """Extracts all text using textract."""
    try:
        text_bytes = textract.process(filepath)
        return text_bytes.decode('utf-8')
    except Exception as e:
        print(f"Error reading document with textract: {e}")
        return None
def parse_and_correct_text(raw_text):
    """Parses text and calls the local correction functions."""
    structured_data = {}
    key_value_pattern = re.compile(r'^\s*(Client Name|Date of Exam|...):\s*(.*)', re.IGNORECASE | re.DOTALL)  # Abridged for brevity
    # This is the key change: we call the local functions directly
    # instead of making API requests.
    for line in raw_text.split('\n'):
        # ... (parsing logic) ...
        # Example of calling the functions directly:
        #   grammar_corrected = run_grammar_correction(value)
        #   final_corrected = run_gender_correction(grammar_corrected)
        pass  # Placeholder for the full parsing logic from your script
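    # A minimal sketch of what the loop body could look like (commented out;
    # it assumes the report uses a flat "Key: value" layout matching the
    # abridged regex above, so adapt it to the actual document structure):
    #
    #   match = key_value_pattern.match(line)
    #   if match:
    #       key, value = match.group(1).strip(), match.group(2).strip()
    #       structured_data[key] = run_gender_correction(run_grammar_correction(value))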
    # Dummy data to demonstrate PDF generation
    structured_data['Client Name'] = run_grammar_correction("Morgan & Morgan")
    structured_data['Intake'] = run_gender_correction(run_grammar_correction("The IME physician asked the examinee if he has any issues sleeping. The examinee replied yes."))
    return structured_data
class PDF(FPDF):
    """Custom PDF class with Unicode font support."""
    def header(self):
        # Fonts are registered once in generate_pdf() before the first page,
        # so header() only needs to select them here (header() runs on every
        # add_page() call, and re-registering the font each time is wasteful).
        self.set_font('DejaVu', 'B', 15)
        self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
        self.ln(10)
    def footer(self):
        self.set_y(-15)
        self.set_font('Helvetica', 'I', 8)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')
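# NOTE: DejaVuSans.ttf and DejaVuSans-Bold.ttf must be present next to this
# script (or referenced by full path); fpdf fails at add_font() time if the
# TTF file cannot be found.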
def generate_pdf(data, output_path):
    """Generates the final PDF report."""
    pdf = PDF()
    # Register both Unicode font variants before the first add_page() call,
    # since header() uses the bold face immediately.
    pdf.add_font('DejaVu', '', 'DejaVuSans.ttf', uni=True)
    pdf.add_font('DejaVu', 'B', 'DejaVuSans-Bold.ttf', uni=True)
    pdf.add_page()
    pdf.set_font('DejaVu', '', 12)
    for key, value in data.items():
        pdf.set_font('DejaVu', 'B', 12)
        pdf.multi_cell(0, 8, f"{key}:")
        pdf.set_font('DejaVu', '', 12)
        pdf.multi_cell(0, 8, str(value))
        pdf.ln(4)
    pdf.output(output_path)
    print(f"✅ Successfully generated PDF report at: {output_path}")
# --- Main Execution ---
if __name__ == "__main__":
    print("--- Starting Local Test Pipeline ---")
    # 1. Pre-requisite: make sure the models are downloaded.
    #    It's assumed you've run the download_models.py script locally first.
    # 2. Load the models into memory
    if load_all_models():
        # 3. Extract raw text from the input document
        raw_text = extract_text_from_doc(INPUT_DOC_PATH)
        if raw_text:
            # 4. Parse and correct the text
            corrected_data = parse_and_correct_text(raw_text)
            # 5. Generate the final PDF report
            generate_pdf(corrected_data, OUTPUT_PDF_PATH)
    print("--- Pipeline Finished ---")