import json
import argparse
import os
import re
import torch
import torch.nn as nn
from TorchCRF import CRF
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
import fitz  # PyMuPDF
import numpy as np
import cv2
from ultralytics import YOLO
import glob
import pytesseract
from PIL import Image
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import sys
import io
import base64
import tempfile  # Recommended for robust temporary file handling
# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# NOTE: Update these paths to match your environment before running!
WEIGHTS_PATH = '/home/dipesh/Downloads/api-mcq/YOLO_MATH/yolo_split_data/runs/detect/math_figure_detector_v3/weights/best.pt'
DEFAULT_LAYOUTLMV3_MODEL_PATH = "checkpoints/layoutlmv3_trained_20251031_102846_recovered.pth"

# DIRECTORY CONFIGURATION
# NOTE: These are used for temporary data extraction/storage.
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'  # Reserved for Phase 1 output (not referenced below)
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# Detection parameters
CONF_THRESHOLD = 0.2
TARGET_CLASSES = ['figure', 'equation']
IOU_MERGE_THRESHOLD = 0.4
IOA_SUPPRESSION_THRESHOLD = 0.7
LINE_TOLERANCE = 15

# Global counters for sequential numbering across the entire PDF
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# LayoutLMv3 Labels
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION", 2: "I-QUESTION",
    3: "B-OPTION", 4: "I-OPTION",
    5: "B-ANSWER", 6: "I-ANSWER",
    7: "B-SECTION_HEADING", 8: "I-SECTION_HEADING",
    9: "B-PASSAGE", 10: "I-PASSAGE"
}
NUM_LABELS = len(ID_TO_LABEL)

# ============================================================================
# --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS (Word Extraction) ---
# --- (Includes all required geometry and OCR helper functions) ---
# ============================================================================
def calculate_iou(box1, box2):
    """Intersection-over-Union of two (x1, y1, x2, y2) boxes."""
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    box_b_area = (x2_b - x1_b) * (y2_b - y1_b)
    union_area = float(box_a_area + box_b_area - intersection_area)
    return intersection_area / union_area if union_area > 0 else 0
def calculate_ioa(box1, box2):
    """Intersection-over-Area: intersection divided by the area of box1 only."""
    x1_a, y1_a, x2_a, y2_a = box1
    x1_b, y1_b, x2_b, y2_b = box2
    x_left = max(x1_a, x1_b)
    y_top = max(y1_a, y1_b)
    x_right = min(x2_a, x2_b)
    y_bottom = min(y2_a, y2_b)
    intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
    box_a_area = (x2_a - x1_a) * (y2_a - y1_a)
    return intersection_area / box_a_area if box_a_area > 0 else 0
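# Worked example (illustration only): for box1 = (0, 0, 10, 10) and
# box2 = (5, 5, 15, 15), the intersection is 5 * 5 = 25, so
#   calculate_iou -> 25 / (100 + 100 - 25) ≈ 0.143   (symmetric)
#   calculate_ioa -> 25 / 100 = 0.25                 (relative to box1 only)
# IoA is deliberately asymmetric: an OCR word box that lies mostly inside a
# detected figure/equation box scores a high IoA even when the IoU is tiny,
# which is what the suppression step in Phase 1 relies on.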
def merge_overlapping_boxes(detections, iou_threshold):
    if not detections:
        return []
    detections.sort(key=lambda d: d['conf'], reverse=True)
    merged_detections = []
    is_merged = [False] * len(detections)
    for i in range(len(detections)):
        if is_merged[i]:
            continue
        current_box = detections[i]['coords']
        current_class = detections[i]['class']
        merged_x1, merged_y1, merged_x2, merged_y2 = current_box
        for j in range(i + 1, len(detections)):
            if is_merged[j] or detections[j]['class'] != current_class:
                continue
            other_box = detections[j]['coords']
            iou = calculate_iou(current_box, other_box)
            if iou > iou_threshold:
                merged_x1 = min(merged_x1, other_box[0])
                merged_y1 = min(merged_y1, other_box[1])
                merged_x2 = max(merged_x2, other_box[2])
                merged_y2 = max(merged_y2, other_box[3])
                is_merged[j] = True
        merged_detections.append({
            'coords': (merged_x1, merged_y1, merged_x2, merged_y2),
            'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf']
        })
    return merged_detections
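# Note: this merge is greedy and single-pass. Detections are visited in
# descending confidence order, each unmerged box absorbs any same-class box
# whose IoU with the *original* anchor box exceeds the threshold, and the
# grown box is not re-compared against earlier results. Chains of boxes that
# only overlap transitively may therefore survive as separate detections.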
def pdf_to_images(pdf_path, temp_dir):
    print("\n[YOLO/OCR STEP 1.1: PDF CONVERSION]")
    try:
        doc = fitz.open(pdf_path)
        pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
        image_paths = []
        mat = fitz.Matrix(2.0, 2.0)
        for page_num in range(doc.page_count):
            page = doc.load_page(page_num)
            pix = page.get_pixmap(matrix=mat)
            img_filename = f"{pdf_name}_page{page_num + 1}.png"
            img_path = os.path.join(temp_dir, img_filename)
            pix.save(img_path)
            image_paths.append(img_path)
        doc.close()
        print(f"  ✅ PDF Conversion complete. {len(image_paths)} images generated.")
        return image_paths
    except Exception as e:
        print(f"❌ ERROR processing PDF {pdf_path}: {e}")
        return []
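# The fitz.Matrix(2.0, 2.0) zoom renders pages at 2x PyMuPDF's 72-dpi base
# (i.e. 144 dpi), so every pixel coordinate produced by YOLO/Tesseract in
# Phase 1 is 2x the PDF's point coordinates. Phase 2 undoes this with the
# matching scale_factor = 2.0 before normalizing boxes for LayoutLMv3; the
# two constants must stay in sync.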
def preprocess_and_ocr_page(image_path, model, pdf_name, page_num):
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    page_filename = os.path.basename(image_path)
    original_img = cv2.imread(image_path)
    if original_img is None:
        return None
    # --- A. YOLO DETECTION AND MERGING ---
    results = model.predict(source=image_path, conf=CONF_THRESHOLD, imgsz=640, verbose=False)
    relevant_detections = []
    if results and results[0].boxes:
        for box in results[0].boxes:
            class_id = int(box.cls[0])
            class_name = model.names[class_id]
            if class_name in TARGET_CLASSES:
                x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int)
                relevant_detections.append(
                    {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])})
    merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD)
    # --- B. COMPONENT EXTRACTION AND TAGGING ---
    component_metadata = []
    for detection in merged_detections:
        x1, y1, x2, y2 = detection['coords']
        class_name = detection['class']
        if class_name == 'figure':
            GLOBAL_FIGURE_COUNT += 1
            counter = GLOBAL_FIGURE_COUNT
            component_word = f"FIGURE{counter}"
        elif class_name == 'equation':
            GLOBAL_EQUATION_COUNT += 1
            counter = GLOBAL_EQUATION_COUNT
            component_word = f"EQUATION{counter}"
        else:
            continue
        component_crop = original_img[y1:y2, x1:x2]
        component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png"
        cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop)
        y_midpoint = (y1 + y2) // 2
        component_metadata.append({
            'type': class_name, 'word': component_word,
            'bbox': [int(x1), int(y1), int(x2), int(y2)],
            'y0': int(y_midpoint), 'x0': int(x1)
        })
    # --- C. TESSERACT OCR ---
    try:
        pil_img = Image.fromarray(cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB))
        hocr_data = pytesseract.image_to_data(pil_img, output_type=pytesseract.Output.DICT)
        raw_ocr_output = []
        for i in range(len(hocr_data['level'])):
            text = hocr_data['text'][i].strip()
            # Cast conf before comparing: some pytesseract versions return it as str.
            if text and float(hocr_data['conf'][i]) > -1:
                x1 = int(hocr_data['left'][i])
                y1 = int(hocr_data['top'][i])
                x2 = x1 + int(hocr_data['width'][i])
                y2 = y1 + int(hocr_data['height'][i])
                raw_ocr_output.append({
                    'type': 'text', 'word': text, 'confidence': float(hocr_data['conf'][i]),
                    'bbox': [x1, y1, x2, y2], 'y0': y1, 'x0': x1
                })
    except Exception as e:
        print(f"  ❌ Tesseract OCR Error on {page_filename}: {e}")
        return None
    # --- D. OCR CLEANING AND MERGING (Using IoA) ---
    # Drop OCR words that are mostly covered by a detected figure/equation box.
    items_to_sort = []
    for ocr_word in raw_ocr_output:
        is_suppressed = False
        for component in component_metadata:
            ioa = calculate_ioa(ocr_word['bbox'], component['bbox'])
            if ioa > IOA_SUPPRESSION_THRESHOLD:
                is_suppressed = True
                break
        if not is_suppressed:
            items_to_sort.append(ocr_word)
    items_to_sort.extend(component_metadata)
    # --- E. LINE-BASED SORTING ---
    # Group items into visual lines by y proximity, then sort each line by x.
    items_to_sort.sort(key=lambda x: (x['y0'], x['x0']))
    lines = []
    for item in items_to_sort:
        placed = False
        for line in lines:
            y_ref = min(it['y0'] for it in line)
            if abs(y_ref - item['y0']) < LINE_TOLERANCE:
                line.append(item)
                placed = True
                break
        # Figures/equations are tall, so retry them with a looser tolerance.
        if not placed and item['type'] in ['equation', 'figure']:
            for line in lines:
                y_ref = min(it['y0'] for it in line)
                if abs(y_ref - item['y0']) < 20:
                    line.append(item)
                    placed = True
                    break
        if not placed:
            lines.append([item])
    for line in lines:
        line.sort(key=lambda x: x['x0'])
    final_output = []
    for line in lines:
        for item in line:
            data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]}
            if 'tag' in item:
                data_item['tag'] = item['tag']
            if 'confidence' in item:
                data_item['confidence'] = item['confidence']
            final_output.append(data_item)
    return final_output
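# Each element of the returned list has the shape (values illustrative):
#   {"word": "EQUATION3", "bbox": [x1, y1, x2, y2], "type": "equation"}
# for detected components, or
#   {"word": "velocity", "bbox": [...], "type": "text", "confidence": 96.0}
# for OCR words, emitted in reading order.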
def get_word_data_for_detection(page: fitz.Page, top_margin_percent=0.10, bottom_margin_percent=0.10) -> list:
    word_data = page.get_text("words")
    if len(word_data) == 0:
        # No text layer: fall back to OCR at 3x zoom, scaling coords back to PDF points.
        try:
            pix = page.get_pixmap(matrix=fitz.Matrix(3, 3))
            img_bytes = pix.tobytes("png")
            img = Image.open(io.BytesIO(img_bytes))
            data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
            full_word_data = []
            for i in range(len(data['level'])):
                if data['text'][i].strip():
                    x1, y1 = data['left'][i] / 3, data['top'][i] / 3
                    x2, y2 = x1 + data['width'][i] / 3, y1 + data['height'][i] / 3
                    full_word_data.append((data['text'][i], x1, y1, x2, y2))
            word_data = full_word_data
        except Exception:
            return []
    else:
        # fitz returns (x0, y0, x1, y1, word, ...); reorder to (word, x0, y0, x1, y1).
        word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data]
    page_height = page.rect.height
    y_min = page_height * top_margin_percent
    y_max = page_height * (1 - bottom_margin_percent)
    # Keep only words fully inside the vertical margins (headers/footers excluded).
    return [d for d in word_data if d[2] >= y_min and d[4] <= y_max]
def calculate_x_gutters(word_data: list, params: Dict) -> List[int]:
    """Finds candidate column gutters as low-density valleys in the x-coordinate histogram."""
    if not word_data:
        return []
    x_points = []
    for _, x1, _, x2, _ in word_data:
        x_points.extend([x1, x2])
    max_x = max(x_points)
    bin_size = params['cluster_bin_size']
    num_bins = int(np.ceil(max_x / bin_size))
    hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x))
    smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=params['cluster_smoothing'])
    # Invert the histogram so that density valleys become peaks for find_peaks.
    inverted_signal = np.max(smoothed_hist) - smoothed_hist
    peaks, properties = find_peaks(
        inverted_signal, height=0, distance=params['cluster_min_width'] / bin_size
    )
    if not peaks.size:
        return []
    threshold_value = np.percentile(smoothed_hist, params['cluster_threshold_percentile'])
    inverted_threshold = np.max(smoothed_hist) - threshold_value
    significant_peaks = peaks[properties['peak_heights'] >= inverted_threshold]
    separator_x_coords = [int(bin_edges[p]) for p in significant_peaks]
    # A true gutter must have dense text on both sides; reject one-sided valleys.
    final_separators = []
    prominence_threshold = params['cluster_prominence'] * np.max(smoothed_hist)
    for x_coord in separator_x_coords:
        bin_idx = np.searchsorted(bin_edges, x_coord) - 1
        window_size = int(params['cluster_min_width'] / bin_size)
        left_start, left_end = max(0, bin_idx - window_size), bin_idx
        right_start, right_end = bin_idx + 1, min(len(smoothed_hist), bin_idx + 1 + window_size)
        if left_end <= left_start or right_end <= right_start:
            continue
        avg_left_density = np.mean(smoothed_hist[left_start:left_end])
        avg_right_density = np.mean(smoothed_hist[right_start:right_end])
        if avg_left_density >= prominence_threshold and avg_right_density >= prominence_threshold:
            final_separators.append(x_coord)
    return sorted(final_separators)
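# In effect: a two-column page yields a bimodal x histogram whose central
# valley (the whitespace gutter) becomes a peak after inversion. The density
# checks on both sides then reject page margins, which are also low-density
# valleys but have text on only one side.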
def detect_column_gutters(pdf_path: str, page_num: int, **params) -> Optional[int]:
    try:
        doc = fitz.open(pdf_path)
        page = doc.load_page(page_num)
        word_data = get_word_data_for_detection(page, params.get('top_margin_percent', 0.10),
                                                params.get('bottom_margin_percent', 0.10))
        page_width = page.rect.width  # Read before close(): the page is invalid afterwards.
        doc.close()
        if not word_data:
            return None
        separators = calculate_x_gutters(word_data, params)
        if len(separators) == 1:
            return separators[0]
        elif len(separators) > 1:
            # Multiple candidates: prefer the gutter closest to the page's center.
            center_x = page_width / 2
            return min(separators, key=lambda x: abs(x - center_x))
        return None
    except Exception:
        return None
def _merge_integrity(all_words_by_page: List[str], all_bboxes_raw: List[List[int]],
                     column_separator_x: Optional[int]) -> List[List[int]]:
    """Splits word *indices* into per-column groups based on box center x.

    Returning indices (rather than the words themselves) keeps words,
    normalized boxes, and PDF-space boxes aligned by construction: the
    inference loop slices all three parallel lists with the same index set.
    Returns a single group covering everything when no gutter was detected.
    """
    if column_separator_x is None:
        return [list(range(len(all_words_by_page)))]
    left_indices, right_indices = [], []
    for idx, bbox_raw in enumerate(all_bboxes_raw):
        center_x = (bbox_raw[0] + bbox_raw[2]) / 2
        if center_x < column_separator_x:
            left_indices.append(idx)
        else:
            right_indices.append(idx)
    return [c for c in [left_indices, right_indices] if c]
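# Example (illustrative): with column_separator_x = 300, a word whose PDF-space
# box is [50, 100, 120, 112] (center x = 85) lands in the left group, while one
# at [320, 100, 380, 112] (center x = 350) lands in the right group. Reading
# order then becomes left column first, then right column, matching how a
# two-column exam paper is actually read.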
def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]:
    """Runs the YOLO/OCR pipeline and returns the path to the combined JSON output."""
    global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT
    # Reset globals for a new PDF run
    GLOBAL_FIGURE_COUNT = 0
    GLOBAL_EQUATION_COUNT = 0
    print("\n" + "=" * 80)
    print("--- 1. STARTING YOLO/OCR PREPROCESSING PIPELINE ---")
    print("=" * 80)
    if not os.path.exists(pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.")
        return None
    if not os.path.exists(WEIGHTS_PATH):
        print(f"❌ FATAL ERROR: YOLO weights not found at {WEIGHTS_PATH}.")
        return None
    # Ensure required directories exist (guard against bare filenames with no directory part)
    out_dir = os.path.dirname(preprocessed_json_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True)
    os.makedirs(TEMP_IMAGE_DIR, exist_ok=True)
    model = YOLO(WEIGHTS_PATH)
    pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
    all_pages_data = []
    image_paths = pdf_to_images(pdf_path, TEMP_IMAGE_DIR)
    if not image_paths:
        print("❌ Pipeline halted. Could not convert any pages from PDF.")
        return None
    print("\n[STEP 1.2: ITERATING PAGES AND RUNNING YOLO/OCR]")
    total_pages_processed = 0
    for i, image_path in enumerate(image_paths):
        page_num = i + 1
        print(f"  -> Processing Page {page_num}/{len(image_paths)}...")
        final_output = preprocess_and_ocr_page(image_path, model, pdf_name, page_num)
        if final_output is not None:
            page_data = {"page_number": page_num, "data": final_output}
            all_pages_data.append(page_data)
            total_pages_processed += 1
        else:
            print(f"  ❌ Skipped page {page_num} due to processing error.")
    # --- FINAL SAVE STEP ---
    if all_pages_data:
        try:
            with open(preprocessed_json_path, 'w') as f:
                json.dump(all_pages_data, f, indent=4)
            print(f"\n  ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}")
        except Exception as e:
            print(f"❌ ERROR saving combined JSON output: {e}")
            return None
    else:
        print("❌ WARNING: No page data generated. Halting pipeline.")
        return None
    print("\n" + "=" * 80)
    print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---")
    print("=" * 80)
    return preprocessed_json_path
# ============================================================================
# --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS (Raw BIO Tagging) ---
# ============================================================================
class LayoutLMv3ForTokenClassification(nn.Module):
    """LayoutLMv3 encoder with a linear emission head and a CRF decoding layer."""

    def __init__(self, num_labels: int = NUM_LABELS):
        super().__init__()
        self.num_labels = num_labels
        config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels)
        self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.crf = CRF(num_labels)
        self.init_weights()

    def init_weights(self):
        nn.init.xavier_uniform_(self.classifier.weight)
        if self.classifier.bias is not None:
            nn.init.zeros_(self.classifier.bias)

    def forward(
        self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, List[List[int]]]:
        outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True)
        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        mask = attention_mask.bool()
        if labels is not None:
            # Training: negative mean CRF log-likelihood as the loss.
            loss = -self.crf(emissions, labels, mask=mask).mean()
            return loss
        else:
            # Inference: Viterbi-decoded best label-id path per sequence.
            return self.crf.viterbi_decode(emissions, mask=mask)
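# The CRF head scores whole label sequences rather than independent tokens,
# which lets decoding favor BIO-consistent outputs (e.g. an I-OPTION directly
# following a B-OPTION) instead of picking each tag greedily. With the TorchCRF
# package used here, calling the layer with labels returns the per-sequence
# log-likelihood (negated and averaged above as the loss), while
# viterbi_decode() returns one best label-id path per batch element.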
def run_inference_and_get_raw_words(pdf_path: str, model_path: str,
                                    preprocessed_json_path: str,
                                    column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]:
    """Runs LayoutLMv3-CRF inference and returns the raw word-level predictions, grouped by page."""
    print("\n" + "=" * 80)
    print("--- 2. STARTING LAYOUTLMV3 INFERENCE PIPELINE ---")
    print("=" * 80)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS)
        checkpoint = torch.load(model_path, map_location=device)
        model_state = checkpoint.get('model_state_dict', checkpoint)
        # Fix for a potential key mismatch in older checkpoints ('layoutlm.' prefix).
        fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()}
        model.load_state_dict(fixed_state_dict)
        model.to(device)
        model.eval()
    except Exception as e:
        print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}")
        return []
    try:
        with open(preprocessed_json_path, 'r', encoding='utf-8') as f:
            preprocessed_data = json.load(f)
    except Exception as e:
        print(f"❌ ERROR loading preprocessed JSON: {e}")
        return []
    try:
        doc = fitz.open(pdf_path)
    except Exception as e:
        print(f"❌ ERROR loading PDF file: {e}")
        return []
    final_page_predictions = []
    CHUNK_SIZE = 500
    for page_data in preprocessed_data:
        page_num_1_based = page_data['page_number']
        page_num_0_based = page_num_1_based - 1
        page_raw_predictions = []
        fitz_page = doc.load_page(page_num_0_based)
        page_width, page_height = fitz_page.rect.width, fitz_page.rect.height
        words, bboxes_raw_pdf_space, normalized_bboxes_list = [], [], []
        scale_factor = 2.0  # Must match the fitz.Matrix zoom used in pdf_to_images().
        for item in page_data['data']:
            word, raw_yolo_bbox = item['word'], item['bbox']
            bbox_pdf = [
                int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor),
                int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor)
            ]
            # LayoutLMv3 expects boxes normalized to a 0-1000 grid.
            normalized_bbox = [
                max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[1] / page_height))),
                max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))),
                max(0, min(1000, int(1000 * bbox_pdf[3] / page_height)))
            ]
            words.append(word)
            bboxes_raw_pdf_space.append(bbox_pdf)
            normalized_bboxes_list.append(normalized_bbox)
        if not words:
            continue
        column_detection_params = column_detection_params or {}
        column_separator_x = detect_column_gutters(pdf_path, page_num_0_based, **column_detection_params)
        index_chunks = _merge_integrity(words, bboxes_raw_pdf_space, column_separator_x)
        for chunk_indices in index_chunks:
            if not chunk_indices:
                continue
            # Slice all three parallel lists with the same index set, so words
            # and boxes stay aligned without any re-matching.
            chunk_words = [k_words for k_words in (words[k] for k in chunk_indices)]
            chunk_normalized_bboxes = [normalized_bboxes_list[k] for k in chunk_indices]
            chunk_bboxes_pdf = [bboxes_raw_pdf_space[k] for k in chunk_indices]
            # --- Inference in sub-batches ---
            for i in range(0, len(chunk_words), CHUNK_SIZE):
                sub_words = chunk_words[i:i + CHUNK_SIZE]
                sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE]
                sub_bboxes_pdf = chunk_bboxes_pdf[i:i + CHUNK_SIZE]
                if not sub_words:
                    continue
                encoded_input = tokenizer(
                    sub_words, boxes=sub_bboxes, truncation=True, padding="max_length",
                    max_length=512, return_tensors="pt"
                )
                input_ids = encoded_input['input_ids'].to(device)
                bbox = encoded_input['bbox'].to(device)
                attention_mask = encoded_input['attention_mask'].to(device)
                with torch.no_grad():
                    predictions_int_list = model(input_ids, bbox, attention_mask)
                if not predictions_int_list:
                    continue
                predictions_int = predictions_int_list[0]
                word_ids = encoded_input.word_ids()
                word_idx_to_pred_id = {}
                for token_idx, word_idx in enumerate(word_ids):
                    if word_idx is not None and word_idx < len(sub_words):
                        # Use the prediction for the first sub-token of each word.
                        if word_idx not in word_idx_to_pred_id:
                            word_idx_to_pred_id[word_idx] = predictions_int[token_idx]
                for current_word_idx in range(len(sub_words)):
                    pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0)
                    pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor
                    predicted_label = ID_TO_LABEL[pred_id]
                    page_raw_predictions.append({
                        "word": sub_words[current_word_idx],
                        "bbox": sub_bboxes_pdf[current_word_idx],
                        "predicted_label": predicted_label,
                        "page_number": page_num_1_based
                    })
        if page_raw_predictions:
            final_page_predictions.append({
                "page_number": page_num_1_based,
                "data": page_raw_predictions
            })
    doc.close()
    print(f"✅ LayoutLMv3 inference complete. Predicted tags for {len(final_page_predictions)} pages.")
    return final_page_predictions
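# Each page entry in the returned list looks like (values illustrative):
#   {"page_number": 1, "data": [
#       {"word": "Q1.", "bbox": [56, 120, 75, 132],
#        "predicted_label": "B-QUESTION", "page_number": 1}, ...]}
# Words that fall past the 512-token truncation limit of a sub-batch default
# to the "O" label via the word_idx_to_pred_id.get(..., 0) fallback above.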
# ============================================================================
# --- PHASE 3: BIO TO STRUCTURED JSON DECODER (Modified for In-Memory Return) ---
# ============================================================================
def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]:
    """
    Reads the page-grouped raw word predictions from input_path, flattens them, and converts
    the BIO tags into the structured JSON format. Returns the structured data.
    """
    print("\n" + "=" * 80)
    print("--- 3. STARTING BIO TO STRUCTURED JSON DECODING ---")
    print("=" * 80)
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            predictions_by_page = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError) as e:
        print(f"❌ Error loading raw prediction file '{input_path}': {e}")
        return None
    except Exception as e:
        print(f"❌ An unexpected error occurred during file loading: {e}")
        return None
    # Flatten the per-page word lists into one sequence across all pages.
    predictions = []
    for page_item in predictions_by_page:
        if isinstance(page_item, dict) and 'data' in page_item and isinstance(page_item['data'], list):
            predictions.extend(page_item['data'])
    if not predictions:
        print("❌ Error: No valid word data found in the input file after attempting to flatten pages.")
        return None
    # --- State machine over the flattened BIO sequence ---
    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False
    is_in_new_passage = False
    def finalize_passage_to_item(item, passage_buffer):
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    for item in predictions:
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (label == 'B-PASSAGE' or label == 'I-PASSAGE')

        # Words before the first question: collect passage text, skip everything else.
        if not first_question_started and label != 'B-QUESTION' and not is_passage_label:
            just_finished_i_option = False
            is_in_new_passage = False
            continue
        if not first_question_started and is_passage_label:
            current_passage_buffer.append(word)
            last_entity_type = 'PASSAGE'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if label == 'B-QUESTION':
            if not first_question_started:
                # Emit any pre-question header/passage text as a METADATA item.
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    if current_passage_buffer:
                        finalize_passage_to_item(metadata_item, current_passage_buffer)
                        if header_text:
                            metadata_item['text'] = header_text
                    elif header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            if current_item is not None:
                # Close out the previous question before starting a new one.
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]
            current_item = {
                'question': word,
                'options': {},
                'answer': '',
                'passage': '',
                'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue
        if current_item is not None:
            if is_in_new_passage:
                # Keep accumulating the mid-question passage until a new entity starts.
                current_item['new_passage'] += f' {word}'
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False
                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue
            is_in_new_passage = False

            if label.startswith('B-'):
                if entity_type != 'PASSAGE':
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        # A passage starting right after the options belongs to the
                        # *next* context: store it separately as 'new_passage'.
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False
            elif label.startswith('I-'):
                if entity_type == 'QUESTION' and current_item.get('question'):
                    current_item['question'] += f' {word}'
                    last_entity_type = 'QUESTION'
                    just_finished_i_option = False
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if last_entity_type == 'QUESTION' and current_item.get('question'):
                            last_entity_type = 'PASSAGE'
                        if last_entity_type == 'PASSAGE' or not current_passage_buffer:
                            current_passage_buffer.append(word)
                            last_entity_type = 'PASSAGE'
                    just_finished_i_option = False
                elif entity_type == 'OPTION' and last_entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True
                elif entity_type == 'ANSWER' and last_entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                    just_finished_i_option = False
                else:
                    just_finished_i_option = False
            elif label == 'O':
                if last_entity_type == 'QUESTION' and current_item and 'question' in current_item:
                    current_item['question'] += f' {word}'
                just_finished_i_option = False
    # --- Finalize last item ---
    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)
    elif not structured_data and current_passage_buffer:
        metadata_item = {'type': 'METADATA', 'passage': ''}
        finalize_passage_to_item(metadata_item, current_passage_buffer)
        metadata_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(metadata_item)

    # --- FINAL CLEANUP ---
    for item in structured_data:
        # Use .get(): METADATA items may lack a 'text' key.
        item['text'] = re.sub(r'\s{2,}', ' ', item.get('text', '')).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    # --- SAVE INTERMEDIATE FILE (Optional for Debugging) ---
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(structured_data, f, indent=2, ensure_ascii=False)
        print(f"✅ Decoding complete. Intermediate structured JSON saved to '{output_path}'.")
    except Exception as e:
        print(f"❌ Error saving intermediate output file: {e}. Returning data anyway.")
    return structured_data
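# A typical question item produced here looks like (values illustrative):
#   {"question": "Q1. What is ...?",
#    "options": {"(a)": "(a) 12", "(b)": "(b) 14"},
#    "answer": "(b)", "passage": "", "text": "Q1. What is ...? (a) 12 (b) 14"}
# plus an optional 'new_passage' key when a passage begins right after the
# options, and an optional leading {'type': 'METADATA', ...} item for any
# header text found before the first question.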
# ============================================================================
# --- PHASE 4: IMAGE EMBEDDING (Modified for In-Memory Return) ---
# ============================================================================
def get_base64_for_file(filepath: str) -> str:
    """Reads a file and returns its Base64 encoded string."""
    try:
        with open(filepath, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')
    except Exception as e:
        print(f"  ❌ Error encoding file {filepath}: {e}")
        return ""


def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]],
                                     figure_extraction_dir: str) -> List[Dict[str, Any]]:
    """
    Scans structured data for EQUATION/FIGURE tags, converts corresponding images
    to Base64, and embeds them into the JSON entry in memory.
    """
    print("\n" + "=" * 80)
    print("--- 4. STARTING IMAGE EMBEDDING (Base64) ---")
    print("=" * 80)
    if not structured_data:
        print("❌ Error: No structured data provided for image embedding.")
        return []
    # 1. Map image tags (e.g., EQUATION9) to their full file paths
    image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png"))
    image_lookup = {}
    tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE)
    for filepath in image_files:
        filename = os.path.basename(filepath)
        match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE)
        if match:
            key = f"{match.group(1).upper()}{match.group(2)}"
            image_lookup[key] = filepath
    print(f"  -> Found {len(image_lookup)} image components in the extraction directory.")
    # 2. Iterate through structured data and collect the tags each item references
    final_structured_data = []
    for item in structured_data:
        text_fields = [item.get('question', ''), item.get('passage', '')]
        if 'options' in item:
            for opt_val in item['options'].values():
                text_fields.append(opt_val)
        if 'new_passage' in item:
            text_fields.append(item['new_passage'])
        unique_tags_to_embed = set()
        for text in text_fields:
            if not text:
                continue
            for match in tag_regex.finditer(text):
                tag = match.group(0).upper()
                if tag in image_lookup:
                    unique_tags_to_embed.add(tag)
        # 3. Embed the Base64 images under lowercase keys (e.g., 'figure3')
        for tag in sorted(unique_tags_to_embed):
            filepath = image_lookup[tag]
            base64_code = get_base64_for_file(filepath)
            item[tag.lower()] = base64_code
        final_structured_data.append(item)
    print("✅ Image embedding complete. Returning final structured data.")
    return final_structured_data
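# After this pass, an item whose question text contains the placeholder
# "FIGURE3" gains a sibling key item["figure3"] holding the Base64-encoded PNG
# crop, so a downstream renderer can substitute the placeholder token with the
# image without touching the filesystem.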
# ============================================================================
# --- MAIN FUNCTION (The Callable Interface) ---
# ============================================================================
def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str) -> Optional[List[Dict[str, Any]]]:
    """
    Executes the full document analysis pipeline:
    YOLO/OCR -> LayoutLMv3 -> Structured JSON -> Base64 Image Embed.

    Args:
        input_pdf_path: Path to the input PDF file.
        layoutlmv3_model_path: Path to the saved LayoutLMv3-CRF PyTorch model checkpoint.

    Returns:
        The final structured JSON data as a Python list of dictionaries, or None on failure.
    """
    if not os.path.exists(input_pdf_path):
        print(f"❌ FATAL ERROR: Input PDF not found at {input_pdf_path}.")
        return None
    if not os.path.exists(layoutlmv3_model_path):
        print(f"❌ FATAL ERROR: LayoutLMv3 model checkpoint not found at {layoutlmv3_model_path}.")
        return None
    if not os.path.exists(WEIGHTS_PATH):
        print(f"❌ FATAL ERROR: YOLO model weights not found at {WEIGHTS_PATH}. Update WEIGHTS_PATH in the script.")
        return None
    print("\n" + "#" * 80)
    print("### STARTING FULL DOCUMENT ANALYSIS PIPELINE ###")
    print("#" * 80)
    # --- Setup Temporary Directories ---
    # Intermediate files live in a per-run directory under the system temp dir.
    pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0]
    temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}")
    os.makedirs(temp_pipeline_dir, exist_ok=True)
    # Define intermediate file paths inside the temp directory
    preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json")
    raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json")
    structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json")
    # Column Detection Parameters
    column_params = {
        'top_margin_percent': 0.10, 'bottom_margin_percent': 0.10, 'cluster_prominence': 0.70,
        'cluster_bin_size': 5, 'cluster_smoothing': 2, 'cluster_threshold_percentile': 30,
        'cluster_min_width': 25,
    }
    final_result = None
    try:
        # --- A. PHASE 1: YOLO/OCR PREPROCESSING ---
        # Saves figure/equation images to FIGURE_EXTRACTION_DIR and OCR data to preprocessed_json_path
        preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path)
        if not preprocessed_json_path_out:
            print("Pipeline aborted after Phase 1.")
            return None
        # --- B. PHASE 2: LAYOUTLMV3 INFERENCE (Raw Output) ---
        page_raw_predictions_list = run_inference_and_get_raw_words(
            input_pdf_path,
            layoutlmv3_model_path,
            preprocessed_json_path_out,
            column_detection_params=column_params
        )
        if not page_raw_predictions_list:
            print("Pipeline aborted: No raw predictions generated in Phase 2.")
            return None
        # Save raw predictions (required input for Phase 3 via file path)
        with open(raw_output_path, 'w', encoding='utf-8') as f:
            json.dump(page_raw_predictions_list, f, indent=4)
        # --- C. PHASE 3: BIO TO STRUCTURED JSON DECODING ---
        structured_data_list = convert_bio_to_structured_json_relaxed(
            raw_output_path,
            structured_intermediate_output_path
        )
        if not structured_data_list:
            print("Pipeline aborted: Failed to convert BIO tags to structured data in Phase 3.")
            return None
        # --- D. PHASE 4: IMAGE EMBEDDING (Base64) ---
        final_result = embed_images_as_base64_in_memory(
            structured_data_list,
            FIGURE_EXTRACTION_DIR
        )
    except Exception as e:
        print(f"❌ FATAL ERROR during pipeline execution: {e}", file=sys.stderr)
        return None
    finally:
        # --- E. Cleanup ---
        # FIGURE_EXTRACTION_DIR is kept (Phase 4 reads from it); only the
        # temporary page images and per-run pipeline files are removed.
        try:
            for f in glob.glob(os.path.join(TEMP_IMAGE_DIR, '*')):
                os.remove(f)
            os.rmdir(TEMP_IMAGE_DIR)
        except Exception:
            pass  # Ignore cleanup errors
        try:
            for f in glob.glob(os.path.join(temp_pipeline_dir, '*')):
                os.remove(f)
            os.rmdir(temp_pipeline_dir)
        except Exception:
            pass
    # --- F. FINAL STATUS ---
    print("\n" + "#" * 80)
    print("### FULL PIPELINE EXECUTION COMPLETE ###")
    print(f"Returning final structured data for {pdf_name}.")
    print("#" * 80)
    return final_result
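# Typical programmatic usage (paths illustrative):
#   data = run_document_pipeline("exam_paper.pdf", DEFAULT_LAYOUTLMV3_MODEL_PATH)
#   if data is not None:
#       print(data[0])  # first structured item (METADATA entry or question dict)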
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Complete Document Analysis Pipeline (YOLO/OCR -> LayoutLMv3 -> Structured JSON -> Base64 Image Embed).")
    parser.add_argument("--input_pdf", type=str, required=True,
                        help="Path to the input PDF file for analysis.")
    parser.add_argument("--layoutlmv3_model_path", type=str,
                        default=DEFAULT_LAYOUTLMV3_MODEL_PATH,
                        help="Path to the saved LayoutLMv3-CRF PyTorch model checkpoint.")
    args = parser.parse_args()
    # --- Call the main function ---
    final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path)
    if final_json_data:
        # Example of what to do with the returned data: save it to a file
        output_file_name = os.path.splitext(os.path.basename(args.input_pdf))[0] + "_final_output_embedded.json"
        # Save the final output in the current working directory
        final_output_path = os.path.abspath(output_file_name)
        with open(final_output_path, 'w', encoding='utf-8') as f:
            json.dump(final_json_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ Final structured data successfully returned and saved to: {final_output_path}")