Enoch Jason J committed
Commit caaf797 · 1 Parent(s): 1ab6f41

Modified app.py

Files changed (2)
  1. app.py +39 -64
  2. local_test.py +180 -0
app.py CHANGED
@@ -1,18 +1,14 @@
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
  import torch
  import re
  import os
-
- # --- Import Libraries ---
+ from unsloth import FastLanguageModel
  from transformers import AutoTokenizer, AutoModelForCausalLM
- from peft import PeftModel
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel

- # --- Model Paths ---
- GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
- BASE_MODEL_PATH = "unsloth/gemma-2b-it"
- # This correctly points to your model on the Hugging Face Hub.
+ # --- Model Paths (These are identifiers for the cached models) ---
  LORA_ADAPTER_PATH = "enoch10jason/gemma-grammar-lora"
+ GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"

  # --- Global variables for models ---
  grammar_model = None
@@ -21,32 +17,25 @@ gender_model = None
  gender_tokenizer = None
  device = "cpu"

- print("--- Starting Model Loading ---")
+ print("--- Starting Model Loading From Cache ---")

  try:
-     # Models are loaded from the pre-downloaded cache in the image.
-     # No token is needed at runtime because the files are already cached.
-     print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
+     # 1. Load your fine-tuned model using Unsloth
+     # This correctly loads the model and applies the adapter.
+     print(f"Loading grammar model and adapter: {LORA_ADAPTER_PATH}")
+     grammar_model, grammar_tokenizer = FastLanguageModel.from_pretrained(
+         model_name=LORA_ADAPTER_PATH,
+         dtype=torch.float32,
+         load_in_4bit=False,  # CPU mode
+     )
+     print("✅ Your fine-tuned grammar model is ready!")
+
+     # 2. Load the gender verifier model
+     print(f"Loading gender model: {GENDER_MODEL_PATH}")
      gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
      gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
      print("✅ Gender verifier model loaded successfully!")

-     print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
-     base_model = AutoModelForCausalLM.from_pretrained(
-         BASE_MODEL_PATH,
-         dtype=torch.float32,
-     ).to(device)
-     grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
-
-     print(f"Applying LoRA adapter from cache: {LORA_ADAPTER_PATH}")
-     grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
-     print("✅ Grammar correction model loaded successfully!")
-
-     if grammar_tokenizer.pad_token is None:
-         grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
-     if gender_tokenizer.pad_token is None:
-         gender_tokenizer.pad_token = gender_tokenizer.eos_token
-
  except Exception as e:
      print(f"❌ Critical error during model loading: {e}")
      grammar_model = None
@@ -54,7 +43,6 @@ except Exception as e:

  print("--- Model Loading Complete ---")

-
  # --- FastAPI Application Setup ---
  app = FastAPI(title="Text Correction API")

@@ -65,57 +53,44 @@ class CorrectionResponse(BaseModel):
      original_text: str
      corrected_text: str

- # --- Helper Functions ---
- def clean_grammar_response(text: str) -> str:
-     if "Response:" in text:
-         parts = text.split("Response:")
-         if len(parts) > 1: return parts[1].strip()
-     return text.strip()
-
- def clean_gender_response(text: str) -> str:
-     if "Response:" in text:
-         parts = text.split("Response:")
-         if len(parts) > 1: text = parts[1].strip()
-     text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', text, flags=re.IGNORECASE)
-     return text.strip().strip('"')
-
- def correct_gender_rules(text: str) -> str:
-     corrections = {
-         r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
-         r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
-     }
-     for pattern, replacement in corrections.items():
-         text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
-     return text
-
  # --- API Endpoints ---
  @app.post("/correct_grammar", response_model=CorrectionResponse)
  async def handle_grammar_correction(request: CorrectionRequest):
-     if not grammar_model or not grammar_tokenizer:
+     if not grammar_model:
          raise HTTPException(status_code=503, detail="Grammar model is not available.")
+
      prompt_text = request.text
      input_text = f"Prompt: {prompt_text}\nResponse:"
      inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
-     output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
+
+     output_ids = grammar_model.generate(**inputs, max_new_tokens=256, do_sample=False)
      output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     corrected = clean_grammar_response(output_text)
+
+     corrected = output_text.split("Response:")[-1].strip()
      return CorrectionResponse(original_text=prompt_text, corrected_text=corrected)

  @app.post("/correct_gender", response_model=CorrectionResponse)
  async def handle_gender_correction(request: CorrectionRequest):
-     if not gender_model or not gender_tokenizer:
+     if not gender_model:
          raise HTTPException(status_code=503, detail="Gender model is not available.")
+
      prompt_text = request.text
      input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{prompt_text}\nResponse:"
      inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
-     output_ids = gender_model.generate(
-         **inputs, max_new_tokens=64, temperature=0.0,
-         do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
-     )
+     output_ids = gender_model.generate(**inputs, max_new_tokens=256, do_sample=False)
      output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     cleaned_from_model = clean_gender_response(output_text)
-     final_correction = correct_gender_rules(cleaned_from_model)
-     return CorrectionResponse(original_text=prompt_text, corrected_text=final_correction)
+
+     cleaned_from_model = output_text.split("Response:")[-1].strip().strip('"')
+
+     # Regex safety net
+     corrections = {
+         r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
+         r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
+     }
+     for pattern, replacement in corrections.items():
+         cleaned_from_model = re.sub(pattern, replacement, cleaned_from_model, flags=re.IGNORECASE)
+
+     return CorrectionResponse(original_text=prompt_text, corrected_text=cleaned_from_model)

  @app.get("/")
  def read_root():
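
For a quick smoke test of the two endpoints after this change, a minimal client sketch (assuming the app is served locally, e.g. with `uvicorn app:app --port 8000`, and that `CorrectionRequest` exposes a single `text` field, as `request.text` implies):

# Minimal sketch, not part of the commit: exercises both endpoints.
# Assumes a local server on port 8000 and a request body of the form
# {"text": ...}; adjust BASE_URL for your deployment.
import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server

def correct(endpoint: str, text: str) -> str:
    resp = requests.post(f"{BASE_URL}/{endpoint}", json={"text": text})
    resp.raise_for_status()  # surfaces the 503 when a model failed to load
    return resp.json()["corrected_text"]

print(correct("correct_grammar", "She go to school every day."))
print(correct("correct_gender", "The examinee said her wife is at home."))
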
local_test.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ import re
+ import os
+ import textract
+ from fpdf import FPDF
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from peft import PeftModel
+
+ # --- Configuration ---
+ # All paths are now local
+ INPUT_DOC_PATH = "Doreen.doc"
+ OUTPUT_PDF_PATH = "Doreen_DeFio_Report_Local_Test.pdf"
+
+ # --- Model Paths (loading from local Hugging Face cache) ---
+ GENDER_MODEL_PATH = "google/gemma-3-270m-qat-q4_0-unquantized"
+ BASE_MODEL_PATH = "unsloth/gemma-2b-it"
+ # FIX: This now points to the local folder containing your fine-tuned model.
+ LORA_ADAPTER_PATH = "gemma-grammar-lora"
+
+ # --- Global variables for models ---
+ grammar_model = None
+ grammar_tokenizer = None
+ gender_model = None
+ gender_tokenizer = None
+ device = "cpu"
+
+ # --- 1. Model Loading Logic (from main.py) ---
+ def load_all_models():
+     """Loads all AI models into memory."""
+     global grammar_model, grammar_tokenizer, gender_model, gender_tokenizer
+     print("--- Starting Model Loading ---")
+     try:
+         print(f"Loading gender model from cache: {GENDER_MODEL_PATH}")
+         gender_tokenizer = AutoTokenizer.from_pretrained(GENDER_MODEL_PATH)
+         gender_model = AutoModelForCausalLM.from_pretrained(GENDER_MODEL_PATH).to(device)
+         print("✅ Gender verifier model loaded successfully!")
+
+         print(f"Loading base model for grammar correction from cache: {BASE_MODEL_PATH}")
+         base_model = AutoModelForCausalLM.from_pretrained(
+             BASE_MODEL_PATH, dtype=torch.float32
+         ).to(device)
+         grammar_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
+
+         print(f"Applying LoRA adapter from local folder: {LORA_ADAPTER_PATH}")
+         grammar_model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_PATH).to(device)
+         print("✅ Grammar correction model loaded successfully!")
+
+         if grammar_tokenizer.pad_token is None:
+             grammar_tokenizer.pad_token = grammar_tokenizer.eos_token
+         if gender_tokenizer.pad_token is None:
+             gender_tokenizer.pad_token = gender_tokenizer.eos_token
+
+     except Exception as e:
+         print(f"❌ Critical error during model loading: {e}")
+         return False
+
+     print("--- Model Loading Complete ---")
+     return True
+
+ # --- 2. Correction Functions (adapted from main.py) ---
+ def run_grammar_correction(text: str) -> str:
+     """Corrects grammar using the loaded LoRA model."""
+     if not grammar_model: return text
+     input_text = f"Prompt: {text}\nResponse:"
+     inputs = grammar_tokenizer(input_text, return_tensors="pt").to(device)
+     output_ids = grammar_model.generate(**inputs, max_new_tokens=64, do_sample=False)
+     output_text = grammar_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     # Cleaning logic
+     if "Response:" in output_text:
+         parts = output_text.split("Response:")
+         if len(parts) > 1: return parts[1].strip()
+     return output_text.strip()
+
+ def run_gender_correction(text: str) -> str:
+     """Corrects gender using the loaded gender model and regex."""
+     if not gender_model: return text
+     input_text = f"Prompt: Please rewrite the sentence with correct grammar and gender. Output ONLY the corrected sentence:\n{text}\nResponse:"
+     inputs = gender_tokenizer(input_text, return_tensors="pt").to(device)
+     output_ids = gender_model.generate(
+         **inputs, max_new_tokens=64, temperature=0.0,
+         do_sample=False, eos_token_id=gender_tokenizer.eos_token_id
+     )
+     output_text = gender_tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+     # Cleaning logic
+     if "Response:" in output_text:
+         parts = output_text.split("Response:")
+         if len(parts) > 1: output_text = parts[1].strip()
+     cleaned_text = re.sub(r'^(Corrected sentence:|Correct:|Prompt:)\s*', '', output_text, flags=re.IGNORECASE).strip().strip('"')
+
+     # Regex safety net
+     corrections = {
+         r'\bher wife\b': 'her husband', r'\bhis husband\b': 'his wife',
+         r'\bhe is a girl\b': 'he is a boy', r'\bshe is a boy\b': 'she is a girl'
+     }
+     for pattern, replacement in corrections.items():
+         cleaned_text = re.sub(pattern, replacement, cleaned_text, flags=re.IGNORECASE)
+     return cleaned_text
+
+ # --- 3. Document Processing Logic (from document_pipeline.py) ---
+ def extract_text_from_doc(filepath):
+     """Extracts all text using textract."""
+     try:
+         text_bytes = textract.process(filepath)
+         return text_bytes.decode('utf-8')
+     except Exception as e:
+         print(f"Error reading document with textract: {e}")
+         return None
+
+ def parse_and_correct_text(raw_text):
+     """Parses text and calls the local correction functions."""
+     structured_data = {}
+     key_value_pattern = re.compile(r'^\s*(Client Name|Date of Exam|...):\s*(.*)', re.IGNORECASE | re.DOTALL)  # Abridged for brevity
+
+     # This is the key change: we call the local functions directly
+     # instead of making API requests.
+     for line in raw_text.split('\n'):
+         # ... (parsing logic) ...
+         # Example of calling the function directly:
+         # corrected_value = run_grammar_correction(value)
+         # final_corrected = run_gender_correction(grammar_corrected)
+         pass  # Placeholder for the full parsing logic from your script
+
+     # Dummy data to demonstrate PDF generation
+     structured_data['Client Name'] = run_grammar_correction("Morgan & Morgan")
+     structured_data['Intake'] = run_gender_correction(run_grammar_correction("The IME physician asked the examinee if he has any issues sleeping. The examinee replied yes."))
+
+     return structured_data
+
+ class PDF(FPDF):
+     """Custom PDF class with Unicode font support."""
+     def header(self):
+         self.add_font('DejaVu', 'B', 'DejaVuSans-Bold.ttf', uni=True)
+         self.set_font('DejaVu', 'B', 15)
+         self.cell(0, 10, 'IME WatchDog Report', 0, 1, 'C')
+         self.ln(10)
+
+     def footer(self):
+         self.set_y(-15)
+         self.set_font('Helvetica', 'I', 8)
+         self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')
+
+ def generate_pdf(data, output_path):
+     """Generates the final PDF report."""
+     pdf = PDF()
+     pdf.add_font('DejaVu', '', 'DejaVuSans.ttf', uni=True)
+     pdf.add_page()
+     pdf.set_font('DejaVu', '', 12)
+
+     for key, value in data.items():
+         pdf.set_font('DejaVu', 'B', 12)
+         pdf.multi_cell(0, 8, f"{key}:")
+         pdf.set_font('DejaVu', '', 12)
+         pdf.multi_cell(0, 8, str(value))
+         pdf.ln(4)
+
+     pdf.output(output_path)
+     print(f"✅ Successfully generated PDF report at: {output_path}")
+
+ # --- Main Execution ---
+ if __name__ == "__main__":
+     print("--- Starting Local Test Pipeline ---")
+
+     # 1. Pre-requisite: Make sure models are downloaded.
+     # It's assumed you've run the download_models.py script locally first.
+
+     # 2. Load the models into memory
+     if load_all_models():
+         # 3. Extract raw text from the input document
+         raw_text = extract_text_from_doc(INPUT_DOC_PATH)
+         if raw_text:
+             # 4. Parse and correct the text
+             corrected_data = parse_and_correct_text(raw_text)
+
+             # 5. Generate the final PDF report
+             generate_pdf(corrected_data, OUTPUT_PDF_PATH)
+
+     print("--- Pipeline Finished ---")
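
The parsing loop in `parse_and_correct_text` is deliberately abridged in this commit. Purely as an illustration of the direct-call pattern its comments describe (the real field list and parsing rules are elided above, and this is not the original script's logic), one possible shape:

# Illustrative sketch only, not the committed parsing logic: shows how a
# key-value line would flow through the two local correction functions.
# The field alternation here is hypothetical; the real list is elided.
def parse_and_correct_text_sketch(raw_text: str) -> dict:
    structured_data = {}
    key_value_pattern = re.compile(r'^\s*(Client Name|Date of Exam):\s*(.*)', re.IGNORECASE)
    for line in raw_text.split('\n'):
        match = key_value_pattern.match(line)
        if not match:
            continue
        key, value = match.group(1).strip(), match.group(2).strip()
        grammar_corrected = run_grammar_correction(value)   # local call, no API request
        structured_data[key] = run_gender_correction(grammar_corrected)
    return structured_data
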