import gradio as gr
import torch
import torch.nn as nn
import pdfplumber
import json
import os
import re
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
from TorchCRF import CRF  # from the TorchCRF package (pip install TorchCRF)

# --- Configuration ---
# Ensure this filename matches exactly what you uploaded to the Space
MODEL_FILENAME = "layoutlmv3_nonlinear_scratch.pth"
BASE_MODEL_ID = "microsoft/layoutlmv3-base"

LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION",
          "B-ANSWER", "I-ANSWER", "B-SECTION_HEADING", "I-SECTION_HEADING",
          "B-PASSAGE", "I-PASSAGE"]
LABEL2ID = {l: i for i, l in enumerate(LABELS)}
ID2LABEL = {i: l for l, i in LABEL2ID.items()}

# ---------------------------------------------------------
# 1. MODEL ARCHITECTURE
# ---------------------------------------------------------
class LayoutLMv3CRF(nn.Module):
    """LayoutLMv3 encoder with a non-linear classifier head and a CRF decoder."""

    def __init__(self, num_labels):
        super().__init__()
        self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
        hidden_size = self.layoutlm.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels)
        )
        self.crf = CRF(num_labels)

    def forward(self, input_ids, bbox, attention_mask, labels=None):
        outputs = self.layoutlm(input_ids=input_ids, bbox=bbox,
                                attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        if labels is not None:
            # Training: the CRF returns the log-likelihood; negate it for the loss.
            log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
            return -log_likelihood.mean()
        # Inference: Viterbi decoding returns the best tag path per sequence.
        return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())

# ---------------------------------------------------------
# 2. MODEL LOADING
# ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
model = None

def load_model():
    """Lazily load the fine-tuned weights once and cache them in a module global."""
    global model
    if model is None:
        print(f"🔄 Loading model from {MODEL_FILENAME}...")
        if not os.path.exists(MODEL_FILENAME):
            raise FileNotFoundError(
                f"Model file {MODEL_FILENAME} not found. Please upload it to the Space."
            )
        model = LayoutLMv3CRF(num_labels=len(LABELS))
        state_dict = torch.load(MODEL_FILENAME, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print("✅ Model loaded successfully.")
    return model
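# Minimal usage sketch for the model above (hypothetical words and boxes; assumes
# the .pth checkpoint is present). It mirrors what process_pdf() below does per
# page, kept commented out here as a reference only:
#
#   m = load_model()
#   words = ["1.", "What", "is", "shown?"]
#   boxes = [[10, 10, 60, 30], [70, 10, 120, 30],
#            [130, 10, 160, 30], [170, 10, 240, 30]]  # 0-1000 normalized coords
#   enc = tokenizer(words, boxes=boxes, return_tensors="pt",
#                   truncation=True, max_length=512)
#   with torch.no_grad():
#       tags = m(enc.input_ids.to(device), enc.bbox.to(device),
#                enc.attention_mask.to(device))
#   print([ID2LABEL[t] for t in tags[0]])  # one tag id per sub-token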
# ---------------------------------------------------------
# 3. CONVERSION LOGIC (Your Custom Function)
# ---------------------------------------------------------
def convert_bio_to_structured_json(predictions):
    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False
    is_in_new_passage = False

    def finalize_passage_to_item(item, passage_buffer):
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()

    # Flatten the per-page predictions into one token stream
    flat_predictions = []
    for page in predictions:
        flat_predictions.extend(page['data'])

    for idx, item in enumerate(flat_predictions):
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None

        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        # Before the first question, collect header/passage text as metadata
        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue

        if label == 'B-QUESTION':
            if not first_question_started:
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]
            current_item = {
                'question': word,
                'options': {},
                'answer': '',
                'passage': '',
                'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue

        if current_item is not None:
            # A passage interrupting the options is tracked as 'new_passage'
            if is_in_new_passage:
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False
                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue
            is_in_new_passage = False

            if label.startswith('B-'):
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False

            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if not current_passage_buffer:
                            last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                just_finished_i_option = (entity_type == 'OPTION')

    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    # Collapse runs of whitespace in the assembled text fields
    for item in structured_data:
        if 'text' in item:
            item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()

    return structured_data

# ---------------------------------------------------------
# 4. PROCESSING PIPELINE
# ---------------------------------------------------------
def process_pdf(pdf_file):
    if pdf_file is None:
        return None, "Please upload a PDF file."
    try:
        model = load_model()

        # Depending on the Gradio version, gr.File yields a filepath string
        # or a tempfile wrapper exposing a .name attribute; handle both.
        pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file

        # 1. Extract words and normalized bounding boxes per page
        extracted_pages = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_idx, page in enumerate(pdf.pages):
                width, height = page.width, page.height
                words_data = page.extract_words()
                page_tokens = []
                page_bboxes = []
                for w in words_data:
                    text = w['text']
                    # LayoutLMv3 expects coordinates scaled to a 0-1000 grid
                    x0 = int((w['x0'] / width) * 1000)
                    top = int((w['top'] / height) * 1000)
                    x1 = int((w['x1'] / width) * 1000)
                    bottom = int((w['bottom'] / height) * 1000)
                    box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
                           max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
                    page_tokens.append(text)
                    page_bboxes.append(box)
                extracted_pages.append({"page_id": page_idx,
                                        "tokens": page_tokens,
                                        "bboxes": page_bboxes})

        # 2. Inference: tag every word with a BIO label
        raw_predictions = []
        for page in extracted_pages:
            tokens = page['tokens']
            bboxes = page['bboxes']
            if not tokens:
                continue
            encoding = tokenizer(tokens, boxes=bboxes, return_tensors="pt",
                                 padding="max_length", truncation=True,
                                 max_length=512, return_offsets_mapping=True)
            input_ids = encoding.input_ids.to(device)
            bbox = encoding.bbox.to(device)
            attention_mask = encoding.attention_mask.to(device)

            with torch.no_grad():
                preds = model(input_ids=input_ids, bbox=bbox,
                              attention_mask=attention_mask)
            pred_tags = preds[0]

            # Map sub-token predictions back to words: keep the first
            # sub-token's tag for each word, skip special tokens (None ids)
            word_ids = encoding.word_ids()
            aligned_data = []
            prev_word_idx = None
            for i, word_idx in enumerate(word_ids):
                if word_idx is None:
                    continue
                if word_idx != prev_word_idx:
                    label_str = ID2LABEL[pred_tags[i]]
                    aligned_data.append({"word": tokens[word_idx],
                                         "predicted_label": label_str})
                prev_word_idx = word_idx
            raw_predictions.append({"data": aligned_data})

        # 3. Structure the BIO tags into question/option/answer items
        final_json = convert_bio_to_structured_json(raw_predictions)

        # Save to file for download
        output_filename = "structured_output.json"
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(final_json, f, indent=2, ensure_ascii=False)

        return output_filename, (f"✅ Successfully processed {len(extracted_pages)} pages. "
                                 f"Found {len(final_json)} structured items.")
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
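# Illustrative shape of structured_output.json (hypothetical values). An optional
# METADATA item holds any header or passage text seen before the first question;
# each question item carries the parsed fields plus its raw 'text' span:
#
# [
#   {"type": "METADATA", "passage": "Read the passage below...", "text": "SECTION A"},
#   {"question": "1. What colour is the sky?",
#    "options": {"(A)": "(A) Red", "(B)": "(B) Blue"},
#    "answer": "B", "passage": "",
#    "text": "1. What colour is the sky? (A) Red (B) Blue"}
# ]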
# ---------------------------------------------------------
# 5. GRADIO INTERFACE
# ---------------------------------------------------------
iface = gr.Interface(
    fn=process_pdf,
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=[
        gr.File(label="Download JSON Output"),
        gr.Textbox(label="Status Log")
    ],
    title="LayoutLMv3 PDF Parser",
    description="Upload a document to extract Questions, Options, and Passages into structured JSON.",
    flagging_mode="never"  # replaces the allow_flagging argument removed in newer Gradio
)

if __name__ == "__main__":
    iface.launch()
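# Assumed runtime dependencies for this Space, inferred from the imports above
# (pin exact versions in requirements.txt as needed): torch, transformers,
# pdfplumber, TorchCRF, and a Gradio release recent enough to accept the
# flagging_mode argument. Running `python app.py` locally serves the interface
# on Gradio's default port, 7860.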