import gradio as gr
import torch
import torch.nn as nn
import pdfplumber
import json
import os
import re
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Model
from TorchCRF import CRF
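
# NOTE: `from TorchCRF import CRF` assumes the TorchCRF package on PyPI, whose CRF
# is constructed with the number of labels. It is not interchangeable with
# pytorch-crf, which is imported as `torchcrf` and has a different constructor.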
# --- Configuration ---
# Ensure this filename matches exactly what you uploaded to the Space
MODEL_FILENAME = "layoutlmv3_nonlinear_scratch.pth"
BASE_MODEL_ID = "microsoft/layoutlmv3-base"
LABELS = ["O", "B-QUESTION", "I-QUESTION", "B-OPTION", "I-OPTION", "B-ANSWER", "I-ANSWER",
          "B-SECTION_HEADING", "I-SECTION_HEADING", "B-PASSAGE", "I-PASSAGE"]
LABEL2ID = {l: i for i, l in enumerate(LABELS)}
ID2LABEL = {i: l for l, i in LABEL2ID.items()}
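
# The labels follow the BIO scheme: "B-" marks the first token of an entity and
# "I-" marks a continuation token, e.g. "What" -> B-QUESTION, "is" -> I-QUESTION.
# LABEL2ID / ID2LABEL simply map between these strings and the indices 0..10.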
# ---------------------------------------------------------
# 1. MODEL ARCHITECTURE
# ---------------------------------------------------------
class LayoutLMv3CRF(nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.layoutlm = LayoutLMv3Model.from_pretrained(BASE_MODEL_ID)
        hidden_size = self.layoutlm.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_labels)
        )
        self.crf = CRF(num_labels)
    def forward(self, input_ids, bbox, attention_mask, labels=None):
        outputs = self.layoutlm(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        emissions = self.classifier(sequence_output)
        if labels is not None:
            log_likelihood = self.crf(emissions, labels, mask=attention_mask.bool())
            return -log_likelihood.mean()
        else:
            return self.crf.viterbi_decode(emissions, mask=attention_mask.bool())
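
# Note on the CRF layer: with TorchCRF, calling self.crf(emissions, labels, mask=...)
# returns the log-likelihood per sequence (negated above to form a training loss),
# while viterbi_decode returns one list of predicted tag ids per batch element.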
# ---------------------------------------------------------
# 2. MODEL LOADING
# ---------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(BASE_MODEL_ID)
model = None

def load_model():
    global model
    if model is None:
        print(f"🔄 Loading model from {MODEL_FILENAME}...")
        if not os.path.exists(MODEL_FILENAME):
            raise FileNotFoundError(f"Model file {MODEL_FILENAME} not found. Please upload it to the Space.")
        model = LayoutLMv3CRF(num_labels=len(LABELS))
        state_dict = torch.load(MODEL_FILENAME, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        print("✅ Model loaded successfully.")
    return model
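
# load_model() lazily loads the checkpoint on the first request and caches it in
# the module-level `model` variable, so later requests reuse the same instance
# instead of re-reading the .pth file.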
# ---------------------------------------------------------
# 3. CONVERSION LOGIC (Your Custom Function)
# ---------------------------------------------------------
def convert_bio_to_structured_json(predictions):
    structured_data = []
    current_item = None
    current_option_key = None
    current_passage_buffer = []
    current_text_buffer = []
    first_question_started = False
    last_entity_type = None
    just_finished_i_option = False
    is_in_new_passage = False

    def finalize_passage_to_item(item, passage_buffer):
        if passage_buffer:
            passage_text = re.sub(r'\s{2,}', ' ', ' '.join(passage_buffer)).strip()
            if item.get('passage'):
                item['passage'] += ' ' + passage_text
            else:
                item['passage'] = passage_text
            passage_buffer.clear()
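
    # The loop below is effectively a small state machine over the flattened token
    # stream: tokens before the first B-QUESTION become a METADATA item, each
    # B-QUESTION opens a new item, and PASSAGE tokens are buffered until the item
    # that owns them is finalized.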
    # Flatten predictions list
    flat_predictions = []
    for page in predictions:
        flat_predictions.extend(page['data'])

    for idx, item in enumerate(flat_predictions):
        word = item['word']
        label = item['predicted_label']
        entity_type = label[2:].strip() if label.startswith(('B-', 'I-')) else None
        current_text_buffer.append(word)
        previous_entity_type = last_entity_type
        is_passage_label = (entity_type == 'PASSAGE')

        if not first_question_started:
            if label != 'B-QUESTION' and not is_passage_label:
                just_finished_i_option = False
                is_in_new_passage = False
                continue
            if is_passage_label:
                current_passage_buffer.append(word)
                last_entity_type = 'PASSAGE'
                just_finished_i_option = False
                is_in_new_passage = False
                continue
        if label == 'B-QUESTION':
            if not first_question_started:
                header_text = ' '.join(current_text_buffer[:-1]).strip()
                if header_text or current_passage_buffer:
                    metadata_item = {'type': 'METADATA', 'passage': ''}
                    finalize_passage_to_item(metadata_item, current_passage_buffer)
                    if header_text:
                        metadata_item['text'] = header_text
                    structured_data.append(metadata_item)
                first_question_started = True
                current_text_buffer = [word]
            if current_item is not None:
                finalize_passage_to_item(current_item, current_passage_buffer)
                current_item['text'] = ' '.join(current_text_buffer[:-1]).strip()
                structured_data.append(current_item)
                current_text_buffer = [word]
            current_item = {
                'question': word, 'options': {}, 'answer': '', 'passage': '', 'text': ''
            }
            current_option_key = None
            last_entity_type = 'QUESTION'
            just_finished_i_option = False
            is_in_new_passage = False
            continue
        if current_item is not None:
            if is_in_new_passage:
                if 'new_passage' not in current_item:
                    current_item['new_passage'] = word
                else:
                    current_item['new_passage'] += f' {word}'
                if label.startswith('B-') or (label.startswith('I-') and entity_type != 'PASSAGE'):
                    is_in_new_passage = False
                if label.startswith(('B-', 'I-')):
                    last_entity_type = entity_type
                continue
            is_in_new_passage = False
            if label.startswith('B-'):
                if entity_type in ['QUESTION', 'OPTION', 'ANSWER', 'SECTION_HEADING']:
                    finalize_passage_to_item(current_item, current_passage_buffer)
                    current_passage_buffer = []
                last_entity_type = entity_type
                if entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION':
                    current_option_key = word
                    current_item['options'][current_option_key] = word
                    just_finished_i_option = False
                elif entity_type == 'ANSWER':
                    current_item['answer'] = word
                    current_option_key = None
                    just_finished_i_option = False
                elif entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                    just_finished_i_option = False
            elif label.startswith('I-'):
                if entity_type == 'QUESTION':
                    current_item['question'] += f' {word}'
                elif entity_type == 'PASSAGE':
                    if previous_entity_type == 'OPTION' and just_finished_i_option:
                        current_item['new_passage'] = word
                        is_in_new_passage = True
                    else:
                        if not current_passage_buffer:
                            last_entity_type = 'PASSAGE'
                        current_passage_buffer.append(word)
                elif entity_type == 'OPTION' and current_option_key is not None:
                    current_item['options'][current_option_key] += f' {word}'
                    just_finished_i_option = True
                elif entity_type == 'ANSWER':
                    current_item['answer'] += f' {word}'
                just_finished_i_option = (entity_type == 'OPTION')
    if current_item is not None:
        finalize_passage_to_item(current_item, current_passage_buffer)
        current_item['text'] = ' '.join(current_text_buffer).strip()
        structured_data.append(current_item)

    # Clean text
    for item in structured_data:
        if 'text' in item:
            item['text'] = re.sub(r'\s{2,}', ' ', item['text']).strip()
        if 'new_passage' in item:
            item['new_passage'] = re.sub(r'\s{2,}', ' ', item['new_passage']).strip()
    return structured_data
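
# A hypothetical output item, for illustration only (real contents depend on the
# model's predictions for a given PDF):
# {
#     "question": "1. What is the capital of France?",
#     "options": {"A)": "A) Paris", "B)": "B) London"},
#     "answer": "Answer: A",
#     "passage": "",
#     "text": "1. What is the capital of France? A) Paris B) London Answer: A"
# }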
# ---------------------------------------------------------
# 4. PROCESSING PIPELINE
# ---------------------------------------------------------
def process_pdf(pdf_file):
    if pdf_file is None:
        return None, "Please upload a PDF file."
    try:
        model = load_model()
        # 1. Extract
        extracted_pages = []
        # gr.File may pass a temp-file wrapper (older Gradio) or a plain path string
        pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
        with pdfplumber.open(pdf_path) as pdf:
            for page_idx, page in enumerate(pdf.pages):
                width, height = page.width, page.height
                words_data = page.extract_words()
                page_tokens = []
                page_bboxes = []
                for w in words_data:
                    text = w['text']
                    # Normalize coordinates to the 0-1000 grid LayoutLMv3 expects
                    x0 = int((w['x0'] / width) * 1000)
                    top = int((w['top'] / height) * 1000)
                    x1 = int((w['x1'] / width) * 1000)
                    bottom = int((w['bottom'] / height) * 1000)
                    box = [max(0, min(x0, 1000)), max(0, min(top, 1000)),
                           max(0, min(x1, 1000)), max(0, min(bottom, 1000))]
                    page_tokens.append(text)
                    page_bboxes.append(box)
                extracted_pages.append({"page_id": page_idx, "tokens": page_tokens, "bboxes": page_bboxes})
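
        # Example of the normalization above: on a 612 pt-wide page, a word with
        # x0=72 maps to int(72 / 612 * 1000) = 117 on the 0-1000 grid.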
        # 2. Inference
        raw_predictions = []
        for page in extracted_pages:
            tokens = page['tokens']
            bboxes = page['bboxes']
            if not tokens:
                continue
            encoding = tokenizer(tokens, boxes=bboxes, return_tensors="pt",
                                 padding="max_length", truncation=True, max_length=512,
                                 return_offsets_mapping=True)
            input_ids = encoding.input_ids.to(device)
            bbox = encoding.bbox.to(device)
            attention_mask = encoding.attention_mask.to(device)
            with torch.no_grad():
                preds = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
            pred_tags = preds[0]
            # Keep one label per word: only the first sub-token of each word counts
            word_ids = encoding.word_ids()
            aligned_data = []
            prev_word_idx = None
            for i, word_idx in enumerate(word_ids):
                if word_idx is None:
                    continue
                if word_idx != prev_word_idx:
                    label_str = ID2LABEL[pred_tags[i]]
                    aligned_data.append({"word": tokens[word_idx], "predicted_label": label_str})
                prev_word_idx = word_idx
            raw_predictions.append({"data": aligned_data})
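
        # At this point each raw_predictions entry has the shape (illustrative):
        #   {"data": [{"word": "What", "predicted_label": "B-QUESTION"}, ...]}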
        # 3. Structure
        final_json = convert_bio_to_structured_json(raw_predictions)

        # Save to file for download
        output_filename = "structured_output.json"
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(final_json, f, indent=2, ensure_ascii=False)
        return output_filename, f"✅ Successfully processed {len(extracted_pages)} pages. Found {len(final_json)} structured items."
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
# ---------------------------------------------------------
# 5. GRADIO INTERFACE
# ---------------------------------------------------------
iface = gr.Interface(
    fn=process_pdf,
    inputs=gr.File(label="Upload PDF", file_types=[".pdf"]),
    outputs=[
        gr.File(label="Download JSON Output"),
        gr.Textbox(label="Status Log")
    ],
    title="LayoutLMv3 PDF Parser",
    description="Upload a document to extract Questions, Options, and Passages into structured JSON.",
    flagging_mode="never"  # renamed from allow_flagging in newer Gradio releases
)

if __name__ == "__main__":
    iface.launch()
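
# Assumed dependencies for the Space's requirements.txt (unpinned; adjust as needed):
#   gradio, torch, transformers, pdfplumber, TorchCRF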