import torch
import re
import os
import textract
from fpdf import FPDF
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# --- Configuration ---
# All paths are now local.
INPUT_DOC_PATH = "Doreen.doc"
OUTPUT_PDF_PATH = "Doreen_DeFio_Report_Local_Test.pdf"

# --- Model Paths (loading from the local Hugging Face cache) ---
GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
BASE_MODEL_PATH = "unsloth/gemma-2b-it"
# FIX: This now points to the local folder containing your fine-tuned model.
LORA_ADAPTER_PATH = "gemma-grammar-lora"
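# NOTE: PeftModel.from_pretrained expects this folder to contain the files
# that peft saves for an adapter, typically adapter_config.json plus
# adapter_model.safetensors (or adapter_model.bin with older peft versions).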

# --- Global variables for models ---
grammar_model = None
grammar_tokenizer = None
gender_model = None
gender_tokenizer = None
device = "cpu"

# --- 1. Model Loading Logic (from main.py) ---
def load_all_models():
    """Loads all AI models into memory."""
    global grammar_model, grammar_tokenizer, gender_model, gender_tokenizer
    print("--- Starting Model Loading ---")
    try:
        print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
        gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
        gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
        print("✅ Gender verifier model loaded successfully!")

        print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
        # Use torch_dtype here: the bare `dtype` keyword is only accepted by
        # very recent transformers releases.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_PATH, torch_dtype=torch.float32
        ).to(device)
        grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

        print(f"Applying LoRA adapter from local folder: {LORA_ADAPTER_PATH}")
        grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
        print("✅ Grammar correction model loaded successfully!")

        # Gemma tokenizers may ship without a pad token; fall back to EOS.
        if grammar_tokenizer.pad_token is None:
            grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
        if gender_tokenizer.pad_token is None:
            gender_tokenizer.pad_token = gender_tokenizer.eos_token
    except Exception as e:
        print(f"❌ Critical error during model loading: {e}")
        return False
    print("--- Model Loading Complete ---")
    return True
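# Rough sizing note (an estimate, not measured): gemma-2b in float32 is about
# 2.5B params x 4 bytes, roughly 10 GB of CPU RAM, before the LoRA adapter and
# the 270M gender model are counted. Consider torch.bfloat16 if RAM is tight.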

# --- 2. Correction Functions (adapted from main.py) ---
def run_grammar_correction(text: str) -> str:
    """Corrects grammar using the loaded LoRA model."""
    if grammar_model is None:
        return text
    input_text = f"Prompt: {text}\nResponse:"
    inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
    output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
    output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Keep only the text after the "Response:" marker.
    if "Response:" in output_text:
        parts = output_text.split("Response:")
        if len(parts) > 1:
            return parts[1].strip()
    return output_text.strip()
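# Hypothetical usage (the exact output depends on the fine-tuned adapter):
#   run_grammar_correction("The examinee state she feel pain.")
#   # -> e.g. "The examinee states she feels pain."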

def run_gender_correction(text: str) -> str:
    """Corrects gender using the loaded gender model and regex."""
    if gender_model is None:
        return text
    input_text = (
        "Prompt: Please rewrite the sentence with correct grammar and gender. "
        f"Output ONLY the corrected sentence:\n{text}\nResponse:"
    )
    inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
    # Greedy decoding; temperature is dropped because it is ignored when
    # do_sample=False, and recent transformers versions warn on temperature=0.0.
    output_ids = gender_model.generate(
        **inputs, max_new_tokens=64,
        do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
    )
    output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Keep only the text after the "Response:" marker.
    if "Response:" in output_text:
        parts = output_text.split("Response:")
        if len(parts) > 1:
            output_text = parts[1].strip()
    cleaned_text = re.sub(
        r'^(Corrected sentence:|Correct:|Prompt:)\s*', '',
        output_text, flags=re.IGNORECASE
    ).strip().strip('"')
    # Regex safety net for common mismatched-gender phrases.
    corrections = {
        r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
        r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
    }
    for pattern, replacement in corrections.items():
        cleaned_text = re.sub(pattern, replacement, cleaned_text, flags=re.IGNORECASE)
    return cleaned_text

# --- 3. Document Processing Logic (from document_pipeline.py) ---
def extract_text_from_doc(filepath):
    """Extracts all text using textract."""
    try:
        text_bytes = textract.process(filepath)
        return text_bytes.decode('utf-8')
    except Exception as e:
        print(f"Error reading document with textract: {e}")
        return None
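# textract shells out to external tools for legacy formats: for .doc files it
# relies on antiword, which must be installed on the system (for example via
# `apt-get install antiword`).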

def parse_and_correct_text(raw_text):
    """Parses text and calls the local correction functions."""
    structured_data = {}
    # Abridged for brevity; the real pattern lists every report field.
    key_value_pattern = re.compile(
        r'^\s*(Client Name|Date of Exam|...):\s*(.*)', re.IGNORECASE | re.DOTALL
    )
    # This is the key change: we call the local functions directly
    # instead of making API requests.
    for line in raw_text.split('\n'):
        # ... (parsing logic) ...
        # Example of calling the functions directly:
        #   corrected_value = run_grammar_correction(value)
        #   final_corrected = run_gender_correction(corrected_value)
        pass  # Placeholder for the full parsing logic from your script

    # Dummy data to demonstrate PDF generation.
    structured_data['Client Name'] = run_grammar_correction("Morgan & Morgan")
    structured_data['Intake'] = run_gender_correction(run_grammar_correction(
        "The IME physician asked the examinee if he has any issues sleeping. "
        "The examinee replied yes."
    ))
    return structured_data
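
# A minimal sketch of the per-line parsing loop elided above. This helper is
# hypothetical (not part of the original script) and assumes key_value_pattern
# has been filled in with the real field names:
def _parse_key_values(raw_text: str, pattern: re.Pattern) -> dict:
    """Apply both correction passes to every 'Key: value' line that matches."""
    data = {}
    for line in raw_text.split('\n'):
        match = pattern.match(line)
        if match:
            key, value = match.group(1).strip(), match.group(2).strip()
            data[key] = run_gender_correction(run_grammar_correction(value))
    return data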

class PDF(FPDF):
    """Custom PDF class with Unicode font support."""

    def header(self):
        # Fonts are registered once in generate_pdf(); header() runs on every
        # page, so add_font() calls do not belong here.
        self.set_font('DejaVu', 'B', 15)
        self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
        self.ln(10)

    def footer(self):
        self.set_y(-15)
        self.set_font('Helvetica', 'I', 8)
        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')

def generate_pdf(data, output_path):
    """Generates the final PDF report."""
    pdf = PDF()
    # Register both Unicode font variants before add_page() triggers header().
    pdf.add_font('DejaVu', '', 'DejaVuSans.ttf', uni=True)
    pdf.add_font('DejaVu', 'B', 'DejaVuSans-Bold.ttf', uni=True)
    pdf.add_page()
    pdf.set_font('DejaVu', '', 12)
    for key, value in data.items():
        pdf.set_font('DejaVu', 'B', 12)
        pdf.multi_cell(0, 8, f"{key}:")
        pdf.set_font('DejaVu', '', 12)
        pdf.multi_cell(0, 8, str(value))
        pdf.ln(4)
    pdf.output(output_path)
    print(f"✅ Successfully generated PDF report at: {output_path}")

# --- Main Execution ---
if __name__ == "__main__":
    print("--- Starting Local Test Pipeline ---")
    # 1. Prerequisite: make sure the models are downloaded.
    #    It's assumed you've run the download_models.py script locally first.
    # 2. Load the models into memory.
    if load_all_models():
        # 3. Extract raw text from the input document.
        raw_text = extract_text_from_doc(INPUT_DOC_PATH)
        if raw_text:
            # 4. Parse and correct the text.
            corrected_data = parse_and_correct_text(raw_text)
            # 5. Generate the final PDF report.
            generate_pdf(corrected_data, OUTPUT_PDF_PATH)
    print("--- Pipeline Finished ---")