import json
import argparse
import os
import re
import torch
import torch.nn as nn
from TorchCRF import CRF
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model, LayoutLMv3Config
from typing import List, Dict, Any, Optional, Union, Tuple
import fitz  # PyMuPDF
import numpy as np
import cv2
from ultralytics import YOLO
import glob
import pytesseract
from PIL import Image
from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
import sys
import io
import base64
import tempfile
import time
import shutil
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# ============================================================================
# --- CONFIGURATION AND CONSTANTS ---
# ============================================================================
# NOTE: Update these paths to match your environment before running!

# YOLO weights loaded once per PDF by run_single_pdf_preprocessing.
WEIGHTS_PATH = 'YOLO_MATH/yolo_split_data/runs/detect/math_figure_detector_v3/weights/best.pt'
# LayoutLMv3 checkpoint; presumably the default model_path for
# run_inference_and_get_raw_words (its use is not visible here — verify).
DEFAULT_LAYOUTLMV3_MODEL_PATH = "97.pth"

# --- Directory configuration ---
OCR_JSON_OUTPUT_DIR = './ocr_json_output_final'
# Cropped figure/equation PNGs are written here during preprocessing.
FIGURE_EXTRACTION_DIR = './figure_extraction'
TEMP_IMAGE_DIR = './temp_pdf_images'

# --- Detection parameters ---
# Minimum confidence passed to YOLO's predict().
CONF_THRESHOLD = 0.2
# YOLO class names that are kept; every other class is ignored.
TARGET_CLASSES = ['figure', 'equation']
# Same-class boxes with IoU above this are merged into their union hull.
IOU_MERGE_THRESHOLD = 0.4
# OCR words whose IoA with a figure/equation box exceeds this are dropped.
IOA_SUPPRESSION_THRESHOLD = 0.7
# Max y0 difference (pixels of the 2x page render) for two items to share a line.
LINE_TOLERANCE = 15

# --- Similarity ---
SIMILARITY_THRESHOLD = 0.10
RESOLUTION_MARGIN = 0.05

# Sequential figure/equation numbering across the whole PDF; both counters are
# reset at the start of every run_single_pdf_preprocessing call.
GLOBAL_FIGURE_COUNT = 0
GLOBAL_EQUATION_COUNT = 0

# --- LayoutLMv3 labels (BIO tagging scheme) ---
ID_TO_LABEL = {
    0: "O",
    1: "B-QUESTION",
    2: "I-QUESTION",
    3: "B-OPTION",
    4: "I-OPTION",
    5: "B-ANSWER",
    6: "I-ANSWER",
    7: "B-SECTION_HEADING",
    8: "I-SECTION_HEADING",
    9: "B-PASSAGE",
    10: "I-PASSAGE",
}
NUM_LABELS = len(ID_TO_LABEL)

# ============================================================================
# --- PERFORMANCE OPTIMIZATION: OCR CACHE ---
============================================================================ class OCRCache: """Caches OCR results per page to avoid redundant Tesseract runs.""" def __init__(self): self.cache = {} def get_key(self, pdf_path: str, page_num: int) -> str: return f"{pdf_path}:{page_num}" def has_ocr(self, pdf_path: str, page_num: int) -> bool: return self.get_key(pdf_path, page_num) in self.cache def get_ocr(self, pdf_path: str, page_num: int) -> Optional[list]: return self.cache.get(self.get_key(pdf_path, page_num)) def set_ocr(self, pdf_path: str, page_num: int, ocr_data: list): self.cache[self.get_key(pdf_path, page_num)] = ocr_data def clear(self): self.cache.clear() # Global OCR cache instance _ocr_cache = OCRCache() # ============================================================================ # --- PHASE 1: YOLO/OCR PREPROCESSING FUNCTIONS --- # ============================================================================ def calculate_iou(box1, box2): x1_a, y1_a, x2_a, y2_a = box1 x1_b, y1_b, x2_b, y2_b = box2 x_left = max(x1_a, x1_b) y_top = max(y1_a, y1_b) x_right = min(x2_a, x2_b) y_bottom = min(y2_a, y2_b) intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) box_a_area = (x2_a - x1_a) * (y2_a - y1_a) box_b_area = (x2_b - x1_b) * (y2_b - y1_b) union_area = float(box_a_area + box_b_area - intersection_area) return intersection_area / union_area if union_area > 0 else 0 def calculate_ioa(box1, box2): x1_a, y1_a, x2_a, y2_a = box1 x1_b, y1_b, x2_b, y2_b = box2 x_left = max(x1_a, x1_b) y_top = max(y1_a, y1_b) x_right = min(x2_a, x2_b) y_bottom = min(y2_a, y2_b) intersection_area = max(0, x_right - x_left) * max(0, y_bottom - y_top) box_a_area = (x2_a - x1_a) * (y2_a - y1_a) return intersection_area / box_a_area if box_a_area > 0 else 0 def filter_nested_boxes(detections, ioa_threshold=0.80): """ Removes boxes that are inside larger boxes (Containment Check). Prioritizes keeping the LARGEST box (the 'parent' container). 
""" if not detections: return [] # 1. Calculate Area for all detections for d in detections: x1, y1, x2, y2 = d['coords'] d['area'] = (x2 - x1) * (y2 - y1) # 2. Sort by Area Descending (Largest to Smallest) # This ensures we process the 'container' first detections.sort(key=lambda x: x['area'], reverse=True) keep_indices = [] is_suppressed = [False] * len(detections) for i in range(len(detections)): if is_suppressed[i]: continue keep_indices.append(i) box_a = detections[i]['coords'] # Compare with all smaller boxes for j in range(i + 1, len(detections)): if is_suppressed[j]: continue box_b = detections[j]['coords'] # Calculate Intersection x_left = max(box_a[0], box_b[0]) y_top = max(box_a[1], box_b[1]) x_right = min(box_a[2], box_b[2]) y_bottom = min(box_a[3], box_b[3]) if x_right < x_left or y_bottom < y_top: intersection = 0 else: intersection = (x_right - x_left) * (y_bottom - y_top) # Calculate IoA (Intersection over Area of the SMALLER box) # Since we sorted by area, 'box_b' (detections[j]) is the smaller one. area_b = detections[j]['area'] if area_b > 0: ioa_small = intersection / area_b # If the small box is > 90% inside the big box, suppress the small one. 
if ioa_small > ioa_threshold: is_suppressed[j] = True #print(f" [Suppress] Removed nested object inside larger '{detections[i]['class']}'") return [detections[i] for i in keep_indices] def merge_overlapping_boxes(detections, iou_threshold): if not detections: return [] detections.sort(key=lambda d: d['conf'], reverse=True) merged_detections = [] is_merged = [False] * len(detections) for i in range(len(detections)): if is_merged[i]: continue current_box = detections[i]['coords'] current_class = detections[i]['class'] merged_x1, merged_y1, merged_x2, merged_y2 = current_box for j in range(i + 1, len(detections)): if is_merged[j] or detections[j]['class'] != current_class: continue other_box = detections[j]['coords'] iou = calculate_iou(current_box, other_box) if iou > iou_threshold: merged_x1 = min(merged_x1, other_box[0]) merged_y1 = min(merged_y1, other_box[1]) merged_x2 = max(merged_x2, other_box[2]) merged_y2 = max(merged_y2, other_box[3]) is_merged[j] = True merged_detections.append({ 'coords': (merged_x1, merged_y1, merged_x2, merged_y2), 'y1': merged_y1, 'class': current_class, 'conf': detections[i]['conf'] }) return merged_detections def merge_yolo_into_word_data(raw_word_data: list, yolo_detections: list, scale_factor: float) -> list: """ Filters out raw words that are inside YOLO boxes and replaces them with a single solid 'placeholder' block for the column detector. """ if not yolo_detections: return raw_word_data # 1. Convert YOLO boxes (Pixels) to PDF Coordinates (Points) pdf_space_boxes = [] for det in yolo_detections: x1, y1, x2, y2 = det['coords'] pdf_box = ( x1 / scale_factor, y1 / scale_factor, x2 / scale_factor, y2 / scale_factor ) pdf_space_boxes.append(pdf_box) # 2. 
Filter out raw words that are inside YOLO boxes cleaned_word_data = [] for word_tuple in raw_word_data: wx1, wy1, wx2, wy2 = word_tuple[1], word_tuple[2], word_tuple[3], word_tuple[4] w_center_x = (wx1 + wx2) / 2 w_center_y = (wy1 + wy2) / 2 is_inside_yolo = False for px1, py1, px2, py2 in pdf_space_boxes: if px1 <= w_center_x <= px2 and py1 <= w_center_y <= py2: is_inside_yolo = True break if not is_inside_yolo: cleaned_word_data.append(word_tuple) # 3. Add the YOLO boxes themselves as "Solid Words" for i, (px1, py1, px2, py2) in enumerate(pdf_space_boxes): dummy_entry = (f"BLOCK_{i}", px1, py1, px2, py2) cleaned_word_data.append(dummy_entry) return cleaned_word_data # ============================================================================ # --- MISSING HELPER FUNCTION --- # ============================================================================ def preprocess_image_for_ocr(img_np): """ Converts image to grayscale and applies Otsu's Binarization to separate text from background clearly. """ # 1. Convert to Grayscale if needed if len(img_np.shape) == 3: gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) else: gray = img_np # 2. Apply Otsu's Thresholding (Automatic binary threshold) # This makes text solid black and background solid white _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) return thresh def calculate_vertical_gap_coverage(word_data: list, sep_x: int, page_height: float, gutter_width: int = 10) -> float: """ Calculates what percentage of the page's vertical text span is 'cleanly split' by the separator. A valid column split should split > 65% of the page verticality. 
""" if not word_data: return 0.0 # Determine the vertical span of the actual text content y_coords = [w[2] for w in word_data] + [w[4] for w in word_data] # y1 and y2 min_y, max_y = min(y_coords), max(y_coords) total_text_height = max_y - min_y if total_text_height <= 0: return 0.0 # Create a boolean array representing the Y-axis (1 pixel per unit) gap_open_mask = np.ones(int(total_text_height) + 1, dtype=bool) zone_left = sep_x - (gutter_width / 2) zone_right = sep_x + (gutter_width / 2) offset_y = int(min_y) for _, x1, y1, x2, y2 in word_data: # Check if this word horizontally interferes with the separator if x2 > zone_left and x1 < zone_right: y_start_idx = max(0, int(y1) - offset_y) y_end_idx = min(len(gap_open_mask), int(y2) - offset_y) if y_end_idx > y_start_idx: gap_open_mask[y_start_idx:y_end_idx] = False open_pixels = np.sum(gap_open_mask) coverage_ratio = open_pixels / len(gap_open_mask) return coverage_ratio def calculate_x_gutters(word_data: list, params: Dict, page_height: float) -> List[int]: """ Calculates X-axis histogram and validates using BRIDGING DENSITY and Vertical Coverage. """ if not word_data: return [] x_points = [] # Use only word_data elements 1 (x1) and 3 (x2) for item in word_data: x_points.extend([item[1], item[3]]) if not x_points: return [] max_x = max(x_points) # 1. 
Determine total text height for ratio calculation y_coords = [item[2] for item in word_data] + [item[4] for item in word_data] min_y, max_y = min(y_coords), max(y_coords) total_text_height = max_y - min_y if total_text_height <= 0: return [] # Histogram Setup bin_size = params.get('cluster_bin_size', 5) smoothing = params.get('cluster_smoothing', 1) min_width = params.get('cluster_min_width', 20) threshold_percentile = params.get('cluster_threshold_percentile', 85) num_bins = int(np.ceil(max_x / bin_size)) hist, bin_edges = np.histogram(x_points, bins=num_bins, range=(0, max_x)) smoothed_hist = gaussian_filter1d(hist.astype(float), sigma=smoothing) inverted_signal = np.max(smoothed_hist) - smoothed_hist peaks, properties = find_peaks( inverted_signal, height=np.max(inverted_signal) - np.percentile(smoothed_hist, threshold_percentile), distance=min_width / bin_size ) if not peaks.size: return [] separator_x_coords = [int(bin_edges[p]) for p in peaks] final_separators = [] for x_coord in separator_x_coords: # --- CHECK 1: BRIDGING DENSITY (The "Cut Through" Check) --- # Calculate the total vertical height of words that physically cross this line. bridging_height = 0 bridging_count = 0 for item in word_data: wx1, wy1, wx2, wy2 = item[1], item[2], item[3], item[4] # Check if this word physically sits on top of the separator line if wx1 < x_coord and wx2 > x_coord: word_h = wy2 - wy1 bridging_height += word_h bridging_count += 1 # Calculate Ratio: How much of the page's text height is blocked by these crossing words? bridging_ratio = bridging_height / total_text_height # THRESHOLD: If bridging blocks > 8% of page height, REJECT. # This allows for page numbers or headers (usually < 5%) to cross, but NOT paragraphs. 
if bridging_ratio > 0.08: print(f" ❌ Separator X={x_coord} REJECTED: Bridging Ratio {bridging_ratio:.1%} (>15%) cuts through text.") continue # --- CHECK 2: VERTICAL GAP COVERAGE (The "Clean Split" Check) --- # The gap must exist cleanly for > 65% of the text height. coverage = calculate_vertical_gap_coverage(word_data, x_coord, page_height, gutter_width=min_width) if coverage >= 0.80: final_separators.append(x_coord) print(f" -> Separator X={x_coord} ACCEPTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})") else: print(f" ❌ Separator X={x_coord} REJECTED (Coverage: {coverage:.1%}, Bridging: {bridging_ratio:.1%})") return sorted(final_separators) def get_word_data_for_detection(page: fitz.Page, pdf_path: str, page_num: int, top_margin_percent=0.10, bottom_margin_percent=0.10) -> list: """Extract word data with OCR caching to avoid redundant Tesseract runs.""" word_data = page.get_text("words") if len(word_data) > 0: word_data = [(w[4], w[0], w[1], w[2], w[3]) for w in word_data] else: if _ocr_cache.has_ocr(pdf_path, page_num): word_data = _ocr_cache.get_ocr(pdf_path, page_num) else: try: # --- OPTIMIZATION START --- # 1. Render at Higher Resolution (Zoom 4.0 = ~300 DPI) zoom_level = 4.0 pix = page.get_pixmap(matrix=fitz.Matrix(zoom_level, zoom_level)) # 2. Convert directly to OpenCV format (Faster than PIL) img_np = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 3: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) elif pix.n == 4: img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGR) # 3. Apply Preprocessing (Thresholding) processed_img = preprocess_image_for_ocr(img_np) # 4. 
Optimized Tesseract Config # --psm 6: Assume a single uniform block of text (Great for columns/questions) # --oem 3: Default engine (LSTM) custom_config = r'--oem 3 --psm 6' data = pytesseract.image_to_data(processed_img, output_type=pytesseract.Output.DICT, config=custom_config) full_word_data = [] for i in range(len(data['level'])): text = data['text'][i].strip() if text: # Scale coordinates back to PDF points x1 = data['left'][i] / zoom_level y1 = data['top'][i] / zoom_level x2 = (data['left'][i] + data['width'][i]) / zoom_level y2 = (data['top'][i] + data['height'][i]) / zoom_level full_word_data.append((text, x1, y1, x2, y2)) word_data = full_word_data _ocr_cache.set_ocr(pdf_path, page_num, word_data) # --- OPTIMIZATION END --- except Exception as e: print(f" ❌ OCR Error in detection phase: {e}") return [] # Apply margin filtering page_height = page.rect.height y_min = page_height * top_margin_percent y_max = page_height * (1 - bottom_margin_percent) return [d for d in word_data if d[2] >= y_min and d[4] <= y_max] def pixmap_to_numpy(pix: fitz.Pixmap) -> np.ndarray: img_data = pix.samples img = np.frombuffer(img_data, dtype=np.uint8).reshape(pix.height, pix.width, pix.n) if pix.n == 4: img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) elif pix.n == 3: img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) return img def extract_native_words_and_convert(fitz_page, scale_factor: float = 2.0) -> list: raw_word_data = fitz_page.get_text("words") converted_ocr_output = [] DEFAULT_CONFIDENCE = 99.0 for x1, y1, x2, y2, word, *rest in raw_word_data: if not word.strip(): continue x1_pix = int(x1 * scale_factor) y1_pix = int(y1 * scale_factor) x2_pix = int(x2 * scale_factor) y2_pix = int(y2 * scale_factor) converted_ocr_output.append({ 'type': 'text', 'word': word, 'confidence': DEFAULT_CONFIDENCE, 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], 'y0': y1_pix, 'x0': x1_pix }) return converted_ocr_output def preprocess_and_ocr_page(original_img: np.ndarray, model, pdf_path: str, page_num: int, 
fitz_page: fitz.Page, pdf_name: str) -> Tuple[List[Dict[str, Any]], Optional[int]]: """ OPTIMIZED FLOW: 1. Run YOLO to find Equations/Tables. 2. Mask raw text with YOLO boxes. 3. Run Column Detection on the MASKED data. 4. Proceed with OCR (Native or High-Res Tesseract Fallback) and Output. """ global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT start_time_total = time.time() if original_img is None: print(f" ❌ Invalid image for page {page_num}.") return None, None # ==================================================================== # --- STEP 1: YOLO DETECTION --- # ==================================================================== start_time_yolo = time.time() results = model.predict(source=original_img, conf=CONF_THRESHOLD, imgsz=640, verbose=False) relevant_detections = [] if results and results[0].boxes: for box in results[0].boxes: class_id = int(box.cls[0]) class_name = model.names[class_id] if class_name in TARGET_CLASSES: x1, y1, x2, y2 = box.xyxy[0].cpu().numpy().astype(int) relevant_detections.append( {'coords': (x1, y1, x2, y2), 'y1': y1, 'class': class_name, 'conf': float(box.conf[0])} ) merged_detections = merge_overlapping_boxes(relevant_detections, IOU_MERGE_THRESHOLD) print(f" [LOG] YOLO found {len(merged_detections)} objects in {time.time() - start_time_yolo:.3f}s.") # ==================================================================== # --- STEP 2: PREPARE DATA FOR COLUMN DETECTION (MASKING) --- # ==================================================================== # Note: This uses the updated 'get_word_data_for_detection' which has its own optimizations raw_words_for_layout = get_word_data_for_detection( fitz_page, pdf_path, page_num, top_margin_percent=0.10, bottom_margin_percent=0.10 ) masked_word_data = merge_yolo_into_word_data(raw_words_for_layout, merged_detections, scale_factor=2.0) # ==================================================================== # --- STEP 3: COLUMN DETECTION --- # 
==================================================================== page_width_pdf = fitz_page.rect.width page_height_pdf = fitz_page.rect.height column_detection_params = { 'cluster_bin_size': 2, 'cluster_smoothing': 2, 'cluster_min_width': 10, 'cluster_threshold_percentile': 85, } separators = calculate_x_gutters(masked_word_data, column_detection_params, page_height_pdf) page_separator_x = None if separators: central_min = page_width_pdf * 0.35 central_max = page_width_pdf * 0.65 central_separators = [s for s in separators if central_min <= s <= central_max] if central_separators: center_x = page_width_pdf / 2 page_separator_x = min(central_separators, key=lambda x: abs(x - center_x)) print(f" ✅ Column Split Confirmed at X={page_separator_x:.1f}") else: print(" ⚠️ Gutter found off-center. Ignoring.") else: print(" -> Single Column Layout Confirmed.") # ==================================================================== # --- STEP 4: COMPONENT EXTRACTION (Save Images) --- # ==================================================================== start_time_components = time.time() component_metadata = [] fig_count_page = 0 eq_count_page = 0 for detection in merged_detections: x1, y1, x2, y2 = detection['coords'] class_name = detection['class'] if class_name == 'figure': GLOBAL_FIGURE_COUNT += 1 counter = GLOBAL_FIGURE_COUNT component_word = f"FIGURE{counter}" fig_count_page += 1 elif class_name == 'equation': GLOBAL_EQUATION_COUNT += 1 counter = GLOBAL_EQUATION_COUNT component_word = f"EQUATION{counter}" eq_count_page += 1 else: continue component_crop = original_img[y1:y2, x1:x2] component_filename = f"{pdf_name}_page{page_num}_{class_name}{counter}.png" cv2.imwrite(os.path.join(FIGURE_EXTRACTION_DIR, component_filename), component_crop) y_midpoint = (y1 + y2) // 2 component_metadata.append({ 'type': class_name, 'word': component_word, 'bbox': [int(x1), int(y1), int(x2), int(y2)], 'y0': int(y_midpoint), 'x0': int(x1) }) # 
==================================================================== # --- STEP 5: HYBRID OCR (Native Text + Cached Tesseract Fallback) --- # ==================================================================== raw_ocr_output = [] scale_factor = 2.0 # Pipeline standard scale try: # Try getting native text first raw_ocr_output = extract_native_words_and_convert(fitz_page, scale_factor=scale_factor) except Exception as e: print(f" ❌ Native text extraction failed: {e}") # If native text is missing, fall back to OCR if not raw_ocr_output: if _ocr_cache.has_ocr(pdf_path, page_num): print(f" ⚡ Using cached Tesseract OCR for page {page_num}") cached_word_data = _ocr_cache.get_ocr(pdf_path, page_num) for word_tuple in cached_word_data: word_text, x1, y1, x2, y2 = word_tuple # Scale from PDF points to Pipeline Pixels (2.0) x1_pix = int(x1 * scale_factor) y1_pix = int(y1 * scale_factor) x2_pix = int(x2 * scale_factor) y2_pix = int(y2 * scale_factor) raw_ocr_output.append({ 'type': 'text', 'word': word_text, 'confidence': 95.0, 'bbox': [x1_pix, y1_pix, x2_pix, y2_pix], 'y0': y1_pix, 'x0': x1_pix }) else: # === START OF OPTIMIZED OCR BLOCK === try: # 1. Re-render Page at High Resolution (Zoom 4.0 = ~300 DPI) # We do this specifically for OCR accuracy, separate from the pipeline image ocr_zoom = 4.0 pix_ocr = fitz_page.get_pixmap(matrix=fitz.Matrix(ocr_zoom, ocr_zoom)) # Convert PyMuPDF Pixmap to OpenCV format img_ocr_np = np.frombuffer(pix_ocr.samples, dtype=np.uint8).reshape(pix_ocr.height, pix_ocr.width, pix_ocr.n) if pix_ocr.n == 3: img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGB2BGR) elif pix_ocr.n == 4: img_ocr_np = cv2.cvtColor(img_ocr_np, cv2.COLOR_RGBA2BGR) # 2. Preprocess (Binarization) # Ensure 'preprocess_image_for_ocr' is defined at top of file! processed_img = preprocess_image_for_ocr(img_ocr_np) # 3. 
Run Tesseract with Optimized Configuration # --oem 3: Default LSTM engine # --psm 6: Assume a single uniform block of text (Critical for lists/questions) custom_config = r'--oem 3 --psm 6' hocr_data = pytesseract.image_to_data( processed_img, output_type=pytesseract.Output.DICT, config=custom_config ) for i in range(len(hocr_data['level'])): text = hocr_data['text'][i].strip() if text and hocr_data['conf'][i] > -1: # 4. Coordinate Mapping # We scanned at Zoom 4.0, but our pipeline expects Zoom 2.0. # Scale Factor = (Target 2.0) / (Source 4.0) = 0.5 scale_adjustment = scale_factor / ocr_zoom x1 = int(hocr_data['left'][i] * scale_adjustment) y1 = int(hocr_data['top'][i] * scale_adjustment) w = int(hocr_data['width'][i] * scale_adjustment) h = int(hocr_data['height'][i] * scale_adjustment) x2 = x1 + w y2 = y1 + h raw_ocr_output.append({ 'type': 'text', 'word': text, 'confidence': float(hocr_data['conf'][i]), 'bbox': [x1, y1, x2, y2], 'y0': y1, 'x0': x1 }) except Exception as e: print(f" ❌ Tesseract OCR Error: {e}") # === END OF OPTIMIZED OCR BLOCK === # ==================================================================== # --- STEP 6: OCR CLEANING AND MERGING --- # ==================================================================== items_to_sort = [] for ocr_word in raw_ocr_output: is_suppressed = False for component in component_metadata: # Do not include words that are inside figure/equation boxes ioa = calculate_ioa(ocr_word['bbox'], component['bbox']) if ioa > IOA_SUPPRESSION_THRESHOLD: is_suppressed = True break if not is_suppressed: items_to_sort.append(ocr_word) # Add figures/equations back into the flow as "words" items_to_sort.extend(component_metadata) # ==================================================================== # --- STEP 7: LINE-BASED SORTING --- # ==================================================================== items_to_sort.sort(key=lambda x: (x['y0'], x['x0'])) lines = [] for item in items_to_sort: placed = False for line in lines: y_ref 
= min(it['y0'] for it in line) if abs(y_ref - item['y0']) < LINE_TOLERANCE: line.append(item) placed = True break if not placed and item['type'] in ['equation', 'figure']: for line in lines: y_ref = min(it['y0'] for it in line) if abs(y_ref - item['y0']) < 20: line.append(item) placed = True break if not placed: lines.append([item]) for line in lines: line.sort(key=lambda x: x['x0']) final_output = [] for line in lines: for item in line: data_item = {"word": item["word"], "bbox": item["bbox"], "type": item["type"]} if 'tag' in item: data_item['tag'] = item['tag'] final_output.append(data_item) return final_output, page_separator_x def run_single_pdf_preprocessing(pdf_path: str, preprocessed_json_path: str) -> Optional[str]: global GLOBAL_FIGURE_COUNT, GLOBAL_EQUATION_COUNT GLOBAL_FIGURE_COUNT = 0 GLOBAL_EQUATION_COUNT = 0 _ocr_cache.clear() print("\n" + "=" * 80) print("--- 1. STARTING OPTIMIZED YOLO/OCR PREPROCESSING PIPELINE ---") print("=" * 80) if not os.path.exists(pdf_path): print(f"❌ FATAL ERROR: Input PDF not found at {pdf_path}.") return None os.makedirs(os.path.dirname(preprocessed_json_path), exist_ok=True) os.makedirs(FIGURE_EXTRACTION_DIR, exist_ok=True) model = YOLO(WEIGHTS_PATH) pdf_name = os.path.splitext(os.path.basename(pdf_path))[0] try: doc = fitz.open(pdf_path) print(f"✅ Opened PDF: {pdf_name} ({doc.page_count} pages)") except Exception as e: print(f"❌ ERROR loading PDF file: {e}") return None all_pages_data = [] total_pages_processed = 0 mat = fitz.Matrix(2.0, 2.0) print("\n[STEP 1.2: ITERATING PAGES - IN-MEMORY PROCESSING]") for page_num_0_based in range(doc.page_count): page_num = page_num_0_based + 1 print(f" -> Processing Page {page_num}/{doc.page_count}...") fitz_page = doc.load_page(page_num_0_based) try: pix = fitz_page.get_pixmap(matrix=mat) original_img = pixmap_to_numpy(pix) except Exception as e: print(f" ❌ Error converting page {page_num} to image: {e}") continue final_output, page_separator_x = preprocess_and_ocr_page( 
original_img, model, pdf_path, page_num, fitz_page, pdf_name ) if final_output is not None: page_data = { "page_number": page_num, "data": final_output, "column_separator_x": page_separator_x } all_pages_data.append(page_data) total_pages_processed += 1 else: print(f" ❌ Skipped page {page_num} due to processing error.") doc.close() if all_pages_data: try: with open(preprocessed_json_path, 'w') as f: json.dump(all_pages_data, f, indent=4) print(f"\n ✅ Combined structured OCR JSON saved to: {os.path.basename(preprocessed_json_path)}") except Exception as e: print(f"❌ ERROR saving combined JSON output: {e}") return None else: print("❌ WARNING: No page data generated. Halting pipeline.") return None print("\n" + "=" * 80) print(f"--- YOLO/OCR PREPROCESSING COMPLETE ({total_pages_processed} pages processed) ---") print("=" * 80) return preprocessed_json_path # ============================================================================ # --- PHASE 2: LAYOUTLMV3 INFERENCE FUNCTIONS --- # ============================================================================ class LayoutLMv3ForTokenClassification(nn.Module): def __init__(self, num_labels: int = NUM_LABELS): super().__init__() self.num_labels = num_labels config = LayoutLMv3Config.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels) self.layoutlmv3 = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base", config=config) self.classifier = nn.Linear(config.hidden_size, num_labels) self.crf = CRF(num_labels) self.init_weights() def init_weights(self): nn.init.xavier_uniform_(self.classifier.weight) if self.classifier.bias is not None: nn.init.zeros_(self.classifier.bias) def forward(self, input_ids: torch.Tensor, bbox: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None): outputs = self.layoutlmv3(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, return_dict=True) sequence_output = outputs.last_hidden_state emissions = self.classifier(sequence_output) mask = 
attention_mask.bool() if labels is not None: loss = -self.crf(emissions, labels, mask=mask).mean() return loss else: return self.crf.viterbi_decode(emissions, mask=mask) def _merge_integrity(all_token_data: List[Dict[str, Any]], column_separator_x: Optional[int]) -> List[List[Dict[str, Any]]]: """Splits the token data objects into column chunks based on a separator.""" if column_separator_x is None: print(" -> No column separator. Treating as one chunk.") return [all_token_data] left_column_tokens, right_column_tokens = [], [] for token_data in all_token_data: bbox_raw = token_data['bbox_raw_pdf_space'] center_x = (bbox_raw[0] + bbox_raw[2]) / 2 if center_x < column_separator_x: left_column_tokens.append(token_data) else: right_column_tokens.append(token_data) chunks = [c for c in [left_column_tokens, right_column_tokens] if c] print(f" -> Data split into {len(chunks)} column chunk(s) using separator X={column_separator_x}.") return chunks def run_inference_and_get_raw_words(pdf_path: str, model_path: str, preprocessed_json_path: str, column_detection_params: Optional[Dict] = None) -> List[Dict[str, Any]]: print("\n" + "=" * 80) print("--- 2. 
STARTING LAYOUTLMV3 INFERENCE PIPELINE (Raw Word Output) ---") print("=" * 80) tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f" -> Using device: {device}") try: model = LayoutLMv3ForTokenClassification(num_labels=NUM_LABELS) checkpoint = torch.load(model_path, map_location=device) model_state = checkpoint.get('model_state_dict', checkpoint) fixed_state_dict = {key.replace('layoutlm.', 'layoutlmv3.'): value for key, value in model_state.items()} model.load_state_dict(fixed_state_dict) model.to(device) model.eval() print(f"✅ LayoutLMv3 Model loaded successfully from {os.path.basename(model_path)}.") except Exception as e: print(f"❌ FATAL ERROR during LayoutLMv3 model loading: {e}") return [] try: with open(preprocessed_json_path, 'r', encoding='utf-8') as f: preprocessed_data = json.load(f) print(f"✅ Loaded preprocessed data with {len(preprocessed_data)} pages.") except Exception: print("❌ Error loading preprocessed JSON.") return [] try: doc = fitz.open(pdf_path) except Exception: print("❌ Error loading PDF.") return [] final_page_predictions = [] CHUNK_SIZE = 500 for page_data in preprocessed_data: page_num_1_based = page_data['page_number'] page_num_0_based = page_num_1_based - 1 page_raw_predictions = [] print(f"\n *** Processing Page {page_num_1_based} ({len(page_data['data'])} raw tokens) ***") fitz_page = doc.load_page(page_num_0_based) page_width, page_height = fitz_page.rect.width, fitz_page.rect.height print(f" -> Page dimensions: {page_width:.0f}x{page_height:.0f} (PDF points).") all_token_data = [] scale_factor = 2.0 for item in page_data['data']: raw_yolo_bbox = item['bbox'] bbox_pdf = [ int(raw_yolo_bbox[0] / scale_factor), int(raw_yolo_bbox[1] / scale_factor), int(raw_yolo_bbox[2] / scale_factor), int(raw_yolo_bbox[3] / scale_factor) ] normalized_bbox = [ max(0, min(1000, int(1000 * bbox_pdf[0] / page_width))), max(0, min(1000, int(1000 * 
bbox_pdf[1] / page_height))), max(0, min(1000, int(1000 * bbox_pdf[2] / page_width))), max(0, min(1000, int(1000 * bbox_pdf[3] / page_height))) ] all_token_data.append({ "word": item['word'], "bbox_raw_pdf_space": bbox_pdf, "bbox_normalized": normalized_bbox, "item_original_data": item }) if not all_token_data: continue column_separator_x = page_data.get('column_separator_x', None) if column_separator_x is not None: print(f" -> Using SAVED column separator: X={column_separator_x}") else: print(" -> No column separator found. Assuming single chunk.") token_chunks = _merge_integrity(all_token_data, column_separator_x) total_chunks = len(token_chunks) for chunk_idx, chunk_tokens in enumerate(token_chunks): if not chunk_tokens: continue chunk_words = [t['word'] for t in chunk_tokens] chunk_normalized_bboxes = [t['bbox_normalized'] for t in chunk_tokens] total_sub_chunks = (len(chunk_words) + CHUNK_SIZE - 1) // CHUNK_SIZE for i in range(0, len(chunk_words), CHUNK_SIZE): sub_chunk_idx = i // CHUNK_SIZE + 1 sub_words = chunk_words[i:i + CHUNK_SIZE] sub_bboxes = chunk_normalized_bboxes[i:i + CHUNK_SIZE] sub_tokens_data = chunk_tokens[i:i + CHUNK_SIZE] print(f" -> Chunk {chunk_idx + 1}/{total_chunks}, Sub-chunk {sub_chunk_idx}/{total_sub_chunks}: {len(sub_words)} words. 
Running Inference...") encoded_input = tokenizer( sub_words, boxes=sub_bboxes, truncation=True, padding="max_length", max_length=512, return_tensors="pt" ) input_ids = encoded_input['input_ids'].to(device) bbox = encoded_input['bbox'].to(device) attention_mask = encoded_input['attention_mask'].to(device) with torch.no_grad(): predictions_int_list = model(input_ids, bbox, attention_mask) if not predictions_int_list: continue predictions_int = predictions_int_list[0] word_ids = encoded_input.word_ids() word_idx_to_pred_id = {} for token_idx, word_idx in enumerate(word_ids): if word_idx is not None and word_idx < len(sub_words): if word_idx not in word_idx_to_pred_id: word_idx_to_pred_id[word_idx] = predictions_int[token_idx] for current_word_idx in range(len(sub_words)): pred_id_or_tensor = word_idx_to_pred_id.get(current_word_idx, 0) pred_id = pred_id_or_tensor.item() if torch.is_tensor(pred_id_or_tensor) else pred_id_or_tensor predicted_label = ID_TO_LABEL[pred_id] original_token = sub_tokens_data[current_word_idx] page_raw_predictions.append({ "word": original_token['word'], "bbox": original_token['bbox_raw_pdf_space'], "predicted_label": predicted_label, "page_number": page_num_1_based }) if page_raw_predictions: final_page_predictions.append({ "page_number": page_num_1_based, "data": page_raw_predictions }) print(f" *** Page {page_num_1_based} Finalized: {len(page_raw_predictions)} labeled words. 
***") doc.close() print("\n" + "=" * 80) print("--- LAYOUTLMV3 INFERENCE COMPLETE ---") print("=" * 80) return final_page_predictions def create_label_studio_span(page_results, start_idx, end_idx, label): entity_words = [page_results[i]['word'] for i in range(start_idx, end_idx + 1)] entity_bboxes = [page_results[i]['bbox'] for i in range(start_idx, end_idx + 1)] x0 = min(bbox[0] for bbox in entity_bboxes) y0 = min(bbox[1] for bbox in entity_bboxes) x1 = max(bbox[2] for bbox in entity_bboxes) y1 = max(bbox[3] for bbox in entity_bboxes) all_words_on_page = [r['word'] for r in page_results] start_char = len(" ".join(all_words_on_page[:start_idx])) if start_idx != 0: start_char += 1 end_char = start_char + len(" ".join(entity_words)) span_text = " ".join(entity_words) return { "from_name": "label", "to_name": "text", "type": "labels", "value": { "start": start_char, "end": end_char, "text": span_text, "labels": [label], "bbox": {"x": x0, "y": y0, "width": x1 - x0, "height": y1 - y0} }, "score": 0.99 } def convert_raw_predictions_to_label_studio(page_data_list, output_path: str): final_tasks = [] print("\n[PHASE: LABEL STUDIO CONVERSION]") for page_data in page_data_list: page_num = page_data['page_number'] page_results = page_data['data'] if not page_results: continue original_words = [r['word'] for r in page_results] text_string = " ".join(original_words) results = [] current_entity_label = None current_entity_start_word_index = None for i, pred_item in enumerate(page_results): label = pred_item['predicted_label'] tag_only = label.split('-', 1)[-1] if '-' in label else label if label.startswith('B-'): if current_entity_label: results.append(create_label_studio_span(page_results, current_entity_start_word_index, i - 1, current_entity_label)) current_entity_label = tag_only current_entity_start_word_index = i elif label.startswith('I-') and current_entity_label == tag_only: continue else: if current_entity_label: results.append(create_label_studio_span(page_results, 
current_entity_start_word_index, i - 1, current_entity_label)) current_entity_label = None current_entity_start_word_index = None if current_entity_label: results.append(create_label_studio_span(page_results, current_entity_start_word_index, len(page_results) - 1, current_entity_label)) final_tasks.append({ "data": { "text": text_string, "original_words": original_words, "original_bboxes": [r['bbox'] for r in page_results] }, "annotations": [{"result": results}], "meta": {"page_number": page_num} }) with open(output_path, "w", encoding='utf-8') as f: json.dump(final_tasks, f, indent=2, ensure_ascii=False) print(f"\n✅ Label Studio tasks saved to {output_path}.") # ============================================================================ # --- PHASE 3: BIO TO STRUCTURED JSON DECODER --- # ============================================================================ def convert_bio_to_structured_json_relaxed(input_path: str, output_path: str) -> Optional[List[Dict[str, Any]]]: print("\n" + "=" * 80) print("--- 3. 
STARTING BIO TO STRUCTURED JSON DECODING ---") print("=" * 80) try: with open(input_path, 'r', encoding='utf-8') as f: predictions_by_page = json.load(f) except Exception as e: print(f"❌ Error loading raw prediction file: {e}") return None predictions = [] for page_item in predictions_by_page: if isinstance(page_item, dict) and 'data' in page_item: predictions.extend(page_item['data']) structured_data = [] current_item = None current_option_key = None current_passage_buffer = [] current_text_buffer = [] first_question_started = False last_entity_type = None just_finished_i_option = False is_in_new_passage = False def finalize_passage_to_item(item, passage_buffer): if passage_buffer: passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip() if item.get('passage'): item['passage'] += ' ' + passage_text else: item['passage'] = passage_text passage_buffer.clear() for item in predictions: word = item['word'] label = item['predicted_label'] entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None current_text_buffer.append(word) previous_entity_type = last_entity_type is_passage_label = (entity_type == 'PASSAGE') if not first_question_started: if label != 'B-QUESTION' and not is_passage_label: just_finished_i_option = False is_in_new_passage = False continue if is_passage_label: current_passage_buffer.append(word) last_entity_type = 'PASSAGE' just_finished_i_option = False is_in_new_passage = False continue if label == 'B-QUESTION': if not first_question_started: header_text = ' '.join(current_text_buffer[:-1]).strip() if header_text or current_passage_buffer: metadata_item = {'type': 'METADATA', 'passage': ''} finalize_passage_to_item(metadata_item, current_passage_buffer) if header_text: metadata_item['text'] = header_text structured_data.append(metadata_item) first_question_started = True current_text_buffer = [word] if current_item is not None: finalize_passage_to_item(current_item, current_passage_buffer) current_item['text'] = ' 
'.join(current_text_buffer[:-1]).strip() structured_data.append(current_item) current_text_buffer = [word] current_item = { 'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': '' } current_option_key = None last_entity_type = 'QUESTION' just_finished_i_option = False is_in_new_passage = False continue if current_item is not None: if is_in_new_passage: # 🔑 Robust Initialization and Appending for 'new_passage' if 'new_passage' not in current_item: current_item['new_passage'] = word else: current_item['new_passage'] += f' {word}' if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'): is_in_new_passage = False if label.startswith(('B-', 'I-')): last_entity_type = entity_type continue is_in_new_passage = False if label.startswith('B-'): if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']: finalize_passage_to_item(current_item, current_passage_buffer) current_passage_buffer = [] last_entity_type = entity_type if entity_type == 'PASSAGE': if previous_entity_type == 'OPTION' and just_finished_i_option: current_item['new_passage'] = word # Initialize the new passage start is_in_new_passage = True else: current_passage_buffer.append(word) elif entity_type == 'OPTION': current_option_key = word current_item['options'][current_option_key] = word just_finished_i_option = False elif entity_type == 'ANSWER': current_item['answer'] = word current_option_key = None just_finished_i_option = False elif entity_type == 'QUESTION': current_item['question'] += f' {word}' just_finished_i_option = False elif label.startswith('I-'): if entity_type == 'QUESTION': current_item['question'] += f' {word}' elif entity_type == 'PASSAGE': if previous_entity_type == 'OPTION' and just_finished_i_option: current_item['new_passage'] = word # Initialize the new passage start is_in_new_passage = True else: if not current_passage_buffer: last_entity_type = 'PASSAGE' current_passage_buffer.append(word) elif entity_type == 'OPTION' and 
current_option_key is not None: current_item['options'][current_option_key] += f' {word}' just_finished_i_option = True elif entity_type == 'ANSWER': current_item['answer'] += f' {word}' just_finished_i_option = (entity_type == 'OPTION') elif label == 'O': if last_entity_type == 'QUESTION': current_item['question'] += f' {word}' just_finished_i_option = False if current_item is not None: finalize_passage_to_item(current_item, current_passage_buffer) current_item['text'] = ' '.join(current_text_buffer).strip() structured_data.append(current_item) for item in structured_data: item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip() if 'new_passage' in item: item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip() try: with open(output_path, 'w', encoding='utf-8') as f: json.dump(structured_data, f, indent=2, ensure_ascii=False) except Exception: pass return structured_data def create_query_text(entry: Dict[str, Any]) -> str: """Combines question and options into a single string for similarity matching.""" query_parts = [] if entry.get("question"): query_parts.append(entry["question"]) for key in ["options", "options_text"]: options = entry.get(key) if options and isinstance(options, dict): for value in options.values(): if value and isinstance(value, str): query_parts.append(value) return " ".join(query_parts) def calculate_similarity(doc1: str, doc2: str) -> float: """Calculates Cosine Similarity between two text strings.""" if not doc1 or not doc2: return 0.0 def clean_text(text): return re.sub(r'^\s*[\(\d\w]+\.?\s*', '', text, flags=re.MULTILINE) clean_doc1 = clean_text(doc1) clean_doc2 = clean_text(doc2) corpus = [clean_doc1, clean_doc2] try: vectorizer = CountVectorizer(stop_words='english', lowercase=True, token_pattern=r'(?u)\b\w\w+\b') tfidf_matrix = vectorizer.fit_transform(corpus) if tfidf_matrix.shape[1] == 0: return 0.0 vectors = tfidf_matrix.toarray() # Handle cases where vectors might be empty or too short if len(vectors) < 2: return 
0.0 score = cosine_similarity(vectors[0:1], vectors[1:2])[0][0] return score except Exception: return 0.0 def process_context_linking(data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Links questions to passages based on 'passage' flow vs 'new_passage' priority. Includes 'Decay Logic': If 2 consecutive questions fail to match the active passage, the passage context is dropped to prevent false positives downstream. """ print("\n" + "=" * 80) print("--- STARTING CONTEXT LINKING (WITH DECAY LOGIC) ---") print("=" * 80) if not data: return [] # --- PHASE 1: IDENTIFY PASSAGE DEFINERS --- passage_definer_indices = [] for i, entry in enumerate(data): if entry.get("passage") and entry["passage"].strip(): passage_definer_indices.append(i) if entry.get("new_passage") and entry["new_passage"].strip(): if i not in passage_definer_indices: passage_definer_indices.append(i) # --- PHASE 2: CONTEXT TRANSFER & LINKING --- current_passage_text = None current_new_passage_text = None # NEW: Counter to track consecutive linking failures consecutive_failures = 0 MAX_CONSECUTIVE_FAILURES = 2 for i, entry in enumerate(data): item_type = entry.get("type", "Question") # A. UNCONDITIONALLY UPDATE CONTEXTS (And Reset Decay Counter) if entry.get("passage") and entry["passage"].strip(): current_passage_text = entry["passage"] consecutive_failures = 0 # Reset because we have fresh explicit context # print(f" [Flow] Updated Standard Context from Item {i}") if entry.get("new_passage") and entry["new_passage"].strip(): current_new_passage_text = entry["new_passage"] # We don't necessarily reset standard failures here as this is a local override # B. 
QUESTION LINKING if entry.get("question") and item_type != "METADATA": combined_query = create_query_text(entry) # Skip if query is too short (noise) if len(combined_query.strip()) < 5: continue # Calculate scores score_old = calculate_similarity(current_passage_text, combined_query) if current_passage_text else 0.0 score_new = calculate_similarity(current_new_passage_text, combined_query) if current_new_passage_text else 0.0 q_preview = entry['question'][:30] + '...' # RESOLUTION LOGIC linked = False # 1. Prefer New Passage if significantly better if current_new_passage_text and (score_new > score_old + RESOLUTION_MARGIN) and (score_new >= SIMILARITY_THRESHOLD): entry["passage"] = current_new_passage_text print(f" [Linker] 🚀 Q{i} ('{q_preview}') -> NEW PASSAGE (Score: {score_new:.3f})") linked = True # Note: We do not reset 'consecutive_failures' for the standard passage here, # because we matched the *new* passage, not the standard one. # 2. Otherwise use Standard Passage if it meets threshold elif current_passage_text and (score_old >= SIMILARITY_THRESHOLD): entry["passage"] = current_passage_text print(f" [Linker] ✅ Q{i} ('{q_preview}') -> STANDARD PASSAGE (Score: {score_old:.3f})") linked = True consecutive_failures = 0 # Success! Reset the kill switch. if not linked: # 3. DECAY LOGIC if current_passage_text: consecutive_failures += 1 print(f" [Linker] ⚠️ Q{i} NOT LINKED. (Failures: {consecutive_failures}/{MAX_CONSECUTIVE_FAILURES})") if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: print(f" [Linker] 🗑️ Context dropped due to {consecutive_failures} consecutive misses.") current_passage_text = None consecutive_failures = 0 else: print(f" [Linker] ⚠️ Q{i} NOT LINKED (No active context).") # --- PHASE 3: CLEANUP AND INTERPOLATION --- print(" [Linker] Running Cleanup & Interpolation...") # 3A. 
Self-Correction (Remove weak links) for i in passage_definer_indices: entry = data[i] if entry.get("question") and entry.get("type") != "METADATA": passage_to_check = entry.get("passage") or entry.get("new_passage") if passage_to_check: self_sim = calculate_similarity(passage_to_check, create_query_text(entry)) if self_sim < SIMILARITY_THRESHOLD: entry["passage"] = "" if "new_passage" in entry: entry["new_passage"] = "" print(f" [Cleanup] Removed weak link for Q{i}") # 3B. Interpolation (Fill gaps) # We only interpolate if the gap is strictly 1 question wide to avoid undoing the decay logic for i in range(1, len(data) - 1): current_entry = data[i] is_gap = current_entry.get("question") and not current_entry.get("passage") if is_gap: prev_p = data[i - 1].get("passage") next_p = data[i + 1].get("passage") if prev_p and next_p and (prev_p == next_p) and prev_p.strip(): current_entry["passage"] = prev_p print(f" [Linker] 🥪 Q{i} Interpolated from neighbors.") return data def correct_misaligned_options(structured_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: print("\n" + "=" * 80) print("--- 5. 
STARTING POST-PROCESSING: OPTION ALIGNMENT CORRECTION ---") print("=" * 80) tag_pattern = re.compile(r'(EQUATION\d+|FIGURE\d+)') corrected_count = 0 for item in structured_data: if item.get('type') in ['METADATA']: continue options = item.get('options') if not options or len(options) < 2: continue option_keys = list(options.keys()) for i in range(len(option_keys) - 1): current_key = option_keys[i] next_key = option_keys[i + 1] current_value = options[current_key].strip() next_value = options[next_key].strip() is_current_empty = current_value == current_key content_in_next = next_value.replace(next_key, '', 1).strip() tags_in_next = tag_pattern.findall(content_in_next) has_two_tags = len(tags_in_next) == 2 if is_current_empty and has_two_tags: tag_to_move = tags_in_next[0] options[current_key] = f"{current_key} {tag_to_move}".strip() options[next_key] = f"{next_key} {tags_in_next[1]}".strip() corrected_count += 1 print(f"✅ Option alignment correction finished. Total corrections: {corrected_count}.") return structured_data # ============================================================================ # --- PHASE 4: IMAGE EMBEDDING (Base64) --- # ============================================================================ def get_base64_for_file(filepath: str) -> str: try: with open(filepath, 'rb') as f: return base64.b64encode(f.read()).decode('utf-8') except Exception as e: print(f" ❌ Error encoding file {filepath}: {e}") return "" def embed_images_as_base64_in_memory(structured_data: List[Dict[str, Any]], figure_extraction_dir: str) -> List[Dict[str, Any]]: print("\n" + "=" * 80) print("--- 4. 
STARTING IMAGE EMBEDDING (Base64) ---") print("=" * 80) if not structured_data: return [] image_files = glob.glob(os.path.join(figure_extraction_dir, "*.png")) image_lookup = {} tag_regex = re.compile(r'(figure|equation)(\d+)', re.IGNORECASE) for filepath in image_files: filename = os.path.basename(filepath) match = re.search(r'_(figure|equation)(\d+)\.png$', filename, re.IGNORECASE) if match: key = f"{match.group(1).upper()}{match.group(2)}" image_lookup[key] = filepath print(f" -> Found {len(image_lookup)} image components.") final_structured_data = [] for item in structured_data: text_fields = [item.get('question', ''), item.get('passage', '')] if 'options' in item: for opt_val in item['options'].values(): text_fields.append(opt_val) if 'new_passage' in item: text_fields.append(item['new_passage']) unique_tags_to_embed = set() for text in text_fields: if not text: continue for match in tag_regex.finditer(text): tag = match.group(0).upper() if tag in image_lookup: unique_tags_to_embed.add(tag) for tag in sorted(list(unique_tags_to_embed)): filepath = image_lookup[tag] base64_code = get_base64_for_file(filepath) base_key = tag.replace(' ', '').lower() item[base_key] = base64_code final_structured_data.append(item) print(f"✅ Image embedding complete.") return final_structured_data # ============================================================================ # --- MAIN FUNCTION --- # ============================================================================ def run_document_pipeline(input_pdf_path: str, layoutlmv3_model_path: str, label_studio_output_path: str) -> Optional[List[Dict[str, Any]]]: if not os.path.exists(input_pdf_path): return None print("\n" + "#" * 80) print("### STARTING OPTIMIZED FULL DOCUMENT ANALYSIS PIPELINE ###") print("#" * 80) pdf_name = os.path.splitext(os.path.basename(input_pdf_path))[0] temp_pipeline_dir = os.path.join(tempfile.gettempdir(), f"pipeline_run_{pdf_name}_{os.getpid()}") os.makedirs(temp_pipeline_dir, exist_ok=True) 
preprocessed_json_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_preprocessed.json") raw_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_raw_predictions.json") structured_intermediate_output_path = os.path.join(temp_pipeline_dir, f"{pdf_name}_structured_intermediate.json") final_result = None try: # Phase 1: Preprocessing with YOLO First + Masking preprocessed_json_path_out = run_single_pdf_preprocessing(input_pdf_path, preprocessed_json_path) if not preprocessed_json_path_out: return None # Phase 2: Inference page_raw_predictions_list = run_inference_and_get_raw_words( input_pdf_path, layoutlmv3_model_path, preprocessed_json_path_out ) if not page_raw_predictions_list: return None with open(raw_output_path, 'w', encoding='utf-8') as f: json.dump(page_raw_predictions_list, f, indent=4) # Phase 3: Decoding structured_data_list = convert_bio_to_structured_json_relaxed( raw_output_path, structured_intermediate_output_path ) if not structured_data_list: return None structured_data_list = correct_misaligned_options(structured_data_list) structured_data_list = process_context_linking(structured_data_list) try: convert_raw_predictions_to_label_studio(page_raw_predictions_list, label_studio_output_path) except Exception as e: print(f"❌ Error during Label Studio conversion: {e}") # Phase 4: Embedding final_result = embed_images_as_base64_in_memory(structured_data_list, FIGURE_EXTRACTION_DIR) except Exception as e: print(f"❌ FATAL ERROR: {e}") import traceback traceback.print_exc() return None finally: try: for f in glob.glob(os.path.join(temp_pipeline_dir, '*')): os.remove(f) os.rmdir(temp_pipeline_dir) except Exception: pass print("\n" + "#" * 80) print("### OPTIMIZED PIPELINE EXECUTION COMPLETE ###") print("#" * 80) return final_result if __name__ == "__main__": parser = argparse.ArgumentParser(description="Complete Pipeline") parser.add_argument("--input_pdf", type=str, required=True, help="Input PDF") parser.add_argument("--layoutlmv3_model_path", 
type=str, default=DEFAULT_LAYOUTLMV3_MODEL_PATH, help="Model Path") parser.add_argument("--ls_output_path", type=str, default=None, help="Label Studio Output Path") args = parser.parse_args() pdf_name = os.path.splitext(os.path.basename(args.input_pdf))[0] final_output_path = os.path.abspath(f"{pdf_name}_final_output_embedded.json") ls_output_path = os.path.abspath(args.ls_output_path if args.ls_output_path else f"{pdf_name}_label_studio_tasks.json") final_json_data = run_document_pipeline(args.input_pdf, args.layoutlmv3_model_path, ls_output_path) if final_json_data: with open(final_output_path, 'w', encoding='utf-8') as f: json.dump(final_json_data, f, indent=2, ensure_ascii=False) print(f"\n✅ Final Data Saved: {final_output_path}") else: print("\n❌ Pipeline Failed.") sys.exit(1)