Spaces:
Sleeping
Sleeping
| import spaces | |
| import gradio as gr | |
| from huggingface_hub import list_models | |
| from typing import List | |
| import torch | |
| from transformers import DonutProcessor, VisionEncoderDecoderModel | |
| from PIL import Image | |
| import json | |
| import re | |
| import logging | |
| from datasets import load_dataset | |
| import os | |
# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Static assets displayed in the UI (relative paths under the "figs" folder).
_FIGS_DIR = "figs"
README_IMAGE_PATH = os.path.join(_FIGS_DIR, "saliencies-merit-dataset.png")
GIF_PATH = os.path.join(_FIGS_DIR, "demo_samples.gif")

# Module-level caches for the Donut model, its processor, and the dataset.
# They start empty and are populated by the loader functions in this file.
donut_model = donut_processor = dataset = None
def load_merit_dataset():
    """Return the cached MERIT split, loading it on the first call.

    The first call loads the ``test`` split of the ``en-digital-seq``
    configuration of ``de-Rodrigo/merit`` (with ``num_proc=8``) and stores it
    in the module-level ``dataset`` global. Later calls return that cached
    object without loading again.
    """
    global dataset
    if dataset is not None:
        return dataset

    dataset = load_dataset(
        "de-Rodrigo/merit",
        name="en-digital-seq",
        split="test",
        num_proc=8,
    )
    return dataset
def get_image_from_dataset(index):
    """Return the ``"image"`` field of the dataset record at ``index``.

    ``index`` is coerced with ``int()`` before indexing, so float values such
    as those produced by a slider are accepted. The dataset is loaded through
    ``load_merit_dataset`` if it has not been loaded yet.
    """
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()

    record = dataset[int(index)]
    return record["image"]
def get_collection_models(tag: str, author: str = "de-Rodrigo") -> List[str]:
    """List the Hub model ids of ``author`` that carry ``tag``.

    Args:
        tag: Tag that a model must have to be included.
        author: Hub user or organisation whose models are listed. Defaults to
            ``"de-Rodrigo"``, the value that was previously hard-coded.

    Returns:
        The ``modelId`` of every matching model, in the order returned by
        ``list_models``.
    """
    return [
        model.modelId
        for model in list_models(author=author)
        # ``tags`` can be None for a model without tag metadata; treat that
        # as "no tags" instead of raising TypeError on the membership test.
        if tag in (model.tags or ())
    ]
def get_donut():
    """Return the cached Donut model and processor, loading them if needed.

    On the first call the ``de-Rodrigo/donut-merit`` checkpoint is loaded and
    the model is moved to ``"cuda"``. The module-level ``donut_model`` and
    ``donut_processor`` globals are only assigned once *every* step has
    succeeded, so a failure part-way through (for example when moving the
    model to the GPU) never leaves a half-initialised pair that a later call
    would hand out as if it were ready.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Exception: Whatever the loading calls raise; the error is logged and
            re-raised unchanged.
    """
    global donut_model, donut_processor
    if donut_model is not None and donut_processor is not None:
        return donut_model, donut_processor

    checkpoint = "de-Rodrigo/donut-merit"
    try:
        model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
        processor = DonutProcessor.from_pretrained(checkpoint)
        model = model.to("cuda")
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Error loading Donut model: %s", e)
        raise

    # Publish to the globals only after the whole load succeeded.
    donut_model, donut_processor = model, processor
    logger.info("Donut model loaded successfully on GPU")
    return donut_model, donut_processor
def process_image_donut(model, processor, image):
    """Run Donut on ``image`` and return the parsed output as a JSON string.

    Args:
        model: Loaded Donut ``VisionEncoderDecoderModel``.
        processor: Matching ``DonutProcessor``.
        image: A PIL image, or an array accepted by ``Image.fromarray``.

    Returns:
        The result of ``processor.token2json`` serialised with
        ``json.dumps(..., indent=2)``. On any failure the exception is logged
        and a string of the form ``"Error: <message>"`` is returned instead
        of raising.
    """
    if image is None:
        # Nothing selected/uploaded: report it clearly instead of surfacing
        # the AttributeError that Image.fromarray(None) would produce.
        return "Error: no image provided"

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # Uploaded files may be RGBA, palette or grayscale; normalise to the
        # 3-channel RGB input the Donut image processor works with.
        image = image.convert("RGB")

        # Put the inputs wherever the model actually lives rather than
        # assuming "cuda".
        device = model.device

        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to(device)

        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        # Drop only the first <...> tag (the task prompt); the remaining tags
        # are needed by token2json.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return json.dumps(result, indent=2)
    except Exception as e:
        logger.exception("Error processing image with Donut: %s", e)
        return f"Error: {e}"
def process_image(model_name, image=None, dataset_image_index=None):
    """Run the selected model on an uploaded image or a dataset image.

    An explicitly supplied ``image`` takes precedence. The dataset image at
    ``dataset_image_index`` is only fetched when no image was supplied.
    Previously the dataset image always replaced the upload whenever an index
    was given; because the UI slider always supplies an index, uploaded
    images were never actually processed.

    Args:
        model_name: Hub id of the model to run.
        image: Optional user-provided image.
        dataset_image_index: Optional index into the MERIT dataset, used as a
            fallback when ``image`` is None.

    Returns:
        An ``(image, result)`` tuple: the image that was used and the model
        output (or an explanatory message) as a string.
    """
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        model, processor = get_donut()
        result = process_image_donut(model, processor, image)
    else:
        # Here you should implement processing for other models
        result = f"Processing for model {model_name} not implemented"

    return image, result
def update_image(dataset_image_index):
    """Return the dataset image at ``dataset_image_index`` for the preview."""
    preview = get_image_from_dataset(dataset_image_index)
    return preview
if __name__ == "__main__":
    # Load the dataset up front; its length sizes the slider below.
    merit_dataset = load_merit_dataset()

    default_model = "de-Rodrigo/donut-merit"
    models = get_collection_models("saliency")
    # Always offer the Donut checkpoint, but do not list it twice if the tag
    # query already returned it.
    if default_model not in models:
        models.append(default_model)

    with gr.Blocks() as demo:
        gr.Markdown("# Saliency Maps with the MERIT Dataset 🎒📃🏆")
        gr.Image(value=README_IMAGE_PATH, label="Example Document")

        with gr.Tab("Introduction"):
            gr.Markdown(
                """
                ## Welcome to Saliency Maps with the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit)
                This space demonstrates the capabilities of different Vision Language models
                for document understanding tasks.
                ### Key Features:
                - Process images from the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) or upload your own image.
                - Use a fine-tuned version of the models available to extract grades from documents.
                - Visualize saliency maps to understand where the model is looking (WIP 🛠️).
                """
            )
            gr.Image(value=GIF_PATH, label="Document Understanding Process")

        with gr.Tab("Try It Yourself"):
            gr.Markdown(
                "Select a model and an image from the dataset, or upload your own image."
            )
            with gr.Row():
                with gr.Column():
                    # Preselect the only model with an implemented pipeline so
                    # the button works without an extra click.
                    model_dropdown = gr.Dropdown(
                        choices=models, value=default_model, label="Select Model"
                    )
                    dataset_slider = gr.Slider(
                        minimum=0,
                        maximum=len(merit_dataset) - 1,
                        step=1,
                        label="Dataset Image Index",
                    )
                    upload_image = gr.Image(
                        type="pil", label="Or Upload Your Own Image"
                    )
                    preview_image = gr.Image(label="Selected/Uploaded Image")
                    process_button = gr.Button("Process Image")
            with gr.Row():
                output_image = gr.Image(label="Processed Image")
                output_text = gr.Textbox(label="Result")

            # Update preview image when slider changes
            dataset_slider.change(
                fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
            )
            # Update preview image when an image is uploaded
            upload_image.change(
                fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
            )
            # Process image when button is clicked
            process_button.click(
                fn=process_image,
                inputs=[model_dropdown, upload_image, dataset_slider],
                outputs=[output_image, output_text],
            )

    demo.launch()