File size: 7,230 Bytes
caaf797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import re
import os
import textract
from fpdf import FPDF
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# --- Configuration ---
# All paths are now local (no remote storage / API endpoints involved).
INPUT_DOC_PATH = "Doreen.doc"  # source document, read via textract
OUTPUT_PDF_PATH = "Doreen_DeFio_Report_Local_Test.pdf"  # destination for the generated report

# --- Model Paths (loading from local Hugging Face cache) ---
# Hub IDs below resolve against the local HF cache; the LoRA adapter is a
# plain local directory produced by fine-tuning.
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"  # small model for gender rewrites
BASE_MODEL_PATH = "unsloth/gemma-2b-it"  # base model the LoRA adapter is applied on top of
# FIX: This now points to the local folder containing your fine-tuned model.
LORA_ADAPTER_PATH = "gemma-grammar-lora"

# --- Global variables for models ---
# Populated by load_all_models(); all remain None until loading succeeds,
# and the correction functions fall back to pass-through while they are None.
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"  # CPU-only inference

# --- 1. Model Loading Logic (from main.py) ---
def load_all_models():
    """Load the gender-verifier and grammar-correction models into module globals.

    Returns:
        True when every model loads; False if any step raises (the globals may
        then be left partially populated).
    """
    global grammar_model, grammar_tokenizer, gender_model, gender_tokenizer
    print("--- Starting Model Loading ---")
    try:
        # Gender verifier: tokenizer + causal LM from the local HF cache.
        print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
        gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
        gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
        print("βœ… Gender verifier model loaded successfully!")

        # Grammar corrector: base model first, then the LoRA adapter on top.
        print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
        foundation = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH, dtype=torch.float32
        ).to(device)
        grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        print(f"Applying LoRA adapter from local folder: {LORA_ADAPTER_PATH}")
        grammar_model = PeftModel.from_pretrained(foundation, LORA_ADAPTER_PATH).to(device)
        print("βœ… Grammar correction model loaded successfully!")

        # Ensure both tokenizers have a pad token so generate() works.
        for tok in (grammar_tokenizer, gender_tokenizer):
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token

    except Exception as e:
        print(f"❌ Critical error during model loading: {e}")
        return False

    print("--- Model Loading Complete ---")
    return True

# --- 2. Correction Functions (adapted from main.py) ---
def run_grammar_correction(text: str) -> str:
    """Correct grammar in *text* using the loaded LoRA model.

    Returns *text* unchanged when the grammar model has not been loaded
    (load_all_models() not run, or it failed).
    """
    if not grammar_model: return text
    input_text = f"Prompt: {text}\nResponse:"
    inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
    output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
    output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # FIX: keep everything after the FIRST "Response:" marker. The previous
    # split() without maxsplit returned only the text between the first and
    # second occurrences, truncating any answer that itself contained the
    # literal string "Response:".
    _, marker, tail = output_text.partition("Response:")
    if marker:
        return tail.strip()
    return output_text.strip()

def run_gender_correction(text: str) -> str:
    """Correct gender references in *text* using the gender model plus a regex safety net.

    Returns *text* unchanged when the gender model has not been loaded.
    """
    if not gender_model: return text
    input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{text}\nResponse:"
    inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
    # FIX: dropped temperature=0.0 -- it is ignored when do_sample=False
    # (greedy decoding) and recent transformers versions reject a zero
    # temperature outright.
    output_ids = gender_model.generate(
        **inputs, max_new_tokens=64,
        do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
    )
    output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # FIX: split with maxsplit=1 so a second literal "Response:" in the
    # model's answer is not truncated away.
    if "Response:" in output_text:
        output_text = output_text.split("Response:", 1)[1].strip()
    cleaned_text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', output_text, flags=re.IGNORECASE).strip().strip('"')

    # Regex safety net for gender swaps the model commonly misses.
    corrections = {
        r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
    }
    for pattern, replacement in corrections.items():
        cleaned_text = re.sub(pattern, replacement, cleaned_text, flags=re.IGNORECASE)
    return cleaned_text

# --- 3. Document Processing Logic (from document_pipeline.py) ---
def extract_text_from_doc(filepath):
    """Return the full text of the document at *filepath* as a UTF-8 str.

    Any extraction or decoding failure is reported to stdout and yields None.
    """
    try:
        return textract.process(filepath).decode('utf-8')
    except Exception as e:
        print(f"Error reading document with textract: {e}")
        return None

def parse_and_correct_text(raw_text):
    """Parse raw document text and run the local correction models over it.

    Returns a dict mapping section names to corrected text. The per-line
    parsing loop is still a placeholder; the entries added at the bottom
    exercise the correction pipeline end to end.
    """
    structured_data = {}
    # FIX: ":\s*" (optional whitespace after the colon) instead of the typo
    # ":s*", which matched a literal run of 's' characters.
    key_value_pattern = re.compile(r'^\s*(Client Name|Date of Exam|...):\s*(.*)', re.IGNORECASE | re.DOTALL)  # Abridged for brevity

    # Key change from the service version: the local correction functions are
    # called directly instead of making API requests.
    for line in raw_text.split('\n'):
        # ... (parsing logic) ...
        # Example of calling the function directly:
        # corrected_value = run_grammar_correction(value)
        # final_corrected = run_gender_correction(corrected_value)
        pass # Placeholder for the full parsing logic from your script

    # Dummy data to demonstrate PDF generation
    structured_data['Client Name'] = run_grammar_correction("Morgan & Morgan")
    structured_data['Intake'] = run_gender_correction(run_grammar_correction("The IME physician asked the examinee if he has any issues sleeping. The examinee replied yes."))

    return structured_data

class PDF(FPDF):
    """Custom PDF class with Unicode (DejaVu) font support."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # FIX: register the bold Unicode font once at construction time.
        # Previously header() called add_font on EVERY page, re-registering
        # the same font over and over.
        self.add_font('DejaVu', 'B', 'DejaVuSans-Bold.ttf', uni=True)

    def header(self):
        # Title banner rendered at the top of every page.
        self.set_font('DejaVu', 'B', 15)
        self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
        self.ln(10)

    def footer(self):
        # Centered page number 15mm above the bottom edge.
        self.set_y(-15)
        self.set_font('Helvetica', 'I', 8)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')

def generate_pdf(data, output_path):
    """Render *data* (section name -> corrected text) as a PDF at *output_path*."""
    report = PDF()
    report.add_font('DejaVu', '', 'DejaVuSans.ttf', uni=True)
    report.add_page()
    report.set_font('DejaVu', '', 12)

    for heading, body in data.items():
        # Bold section heading, then its body in the regular face.
        report.set_font('DejaVu', 'B', 12)
        report.multi_cell(0, 8, f"{heading}:")
        report.set_font('DejaVu', '', 12)
        report.multi_cell(0, 8, str(body))
        report.ln(4)

    report.output(output_path)
    print(f"βœ… Successfully generated PDF report at: {output_path}")

# --- Main Execution ---
if __name__ == "__main__":
    print("--- Starting Local Test Pipeline ---")

    # Pre-requisite: the models are assumed to be downloaded already
    # (run the download_models.py script locally first).
    if load_all_models():
        # Extract, correct, and render -- each stage short-circuits the next
        # on failure (None text skips PDF generation).
        raw_text = extract_text_from_doc(INPUT_DOC_PATH)
        if raw_text:
            corrected_data = parse_and_correct_text(raw_text)
            generate_pdf(corrected_data, OUTPUT_PDF_PATH)

    print("--- Pipeline Finished ---")