import subprocess

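# Install FlashAttention 2 at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips the
# local CUDA compile so a prebuilt wheel is used (standard pattern for Hugging Face Spaces).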
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

import spaces
import gradio as gr
import re

from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import os
import json
from pydantic import BaseModel
from typing import Tuple

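# Enable hf_transfer for faster model downloads from the Hugging Face Hub.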
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

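# Load Qwen2.5-VL-7B-Instruct in bfloat16 with FlashAttention 2;
# device_map="auto" places the weights on the available GPU.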
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
)


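# Schema for the expected model output: three query types, each paired with a short explanation.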
class GeneralRetrievalQuery(BaseModel):
    broad_topical_query: str
    broad_topical_explanation: str
    specific_detail_query: str
    specific_detail_explanation: str
    visual_element_query: str
    visual_element_explanation: str


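# Extract the first fenced ```json ... ``` block from the model's raw output, if present.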
def extract_json_with_regex(text):
    pattern = r'```(?:json)?\s*(.+?)\s*```'
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return matches[0]
    return None


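# Prompt asking the VLM for three retrieval queries (broad topical, specific detail,
# visual element), each with an explanation, returned as a JSON object.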
def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")

    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

Please generate 3 different types of retrieval queries:

1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.

Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.

For each query, also provide a brief explanation of why this query would be effective in retrieving this document.

Format your response as a JSON object with the following structure:

{
    "broad_topical_query": "Your query here",
    "broad_topical_explanation": "Brief explanation",
    "specific_detail_query": "Your query here",
    "specific_detail_explanation": "Brief explanation",
    "visual_element_query": "Your query here",
    "visual_element_explanation": "Brief explanation"
}

If there are no relevant visual elements, replace the third query with another specific detail query.

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format."""

    return prompt, GeneralRetrievalQuery


prompt, pydantic_model = get_retrieval_prompt("general")


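# Build the chat-template text and vision inputs expected by the Qwen2.5-VL processor.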
def _prep_data_for_input(image):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)

    return processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )


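# @spaces.GPU requests a ZeroGPU worker for the duration of each call when running on
# Hugging Face Spaces; the output is the parsed JSON if parsing succeeds, else the raw text.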
@spaces.GPU
def generate_response(image):
    inputs = _prep_data_for_input(image)
    inputs = inputs.to("cuda")

    generated_ids = model.generate(**inputs, max_new_tokens=200)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    try:
        json_str = extract_json_with_regex(output_text)
        if json_str:
            parsed = json.loads(json_str)
            return json.dumps(parsed, indent=2)
        parsed = json.loads(output_text)
        return json.dumps(parsed, indent=2)
    except Exception:
        gr.Warning("Failed to parse JSON from output")
        return output_text


title = "ColPali Query Generator using Qwen2.5-VL"
description = """[ColPali](https://huggingface.co/papers/2407.01449) is an exciting new approach to multimodal document retrieval. It aims to replace existing document retrievers, which often rely on an OCR step, with an end-to-end multimodal approach.

To train or fine-tune a ColPali model, we need a dataset of image-text pairs representing document page images and the text queries those documents should match.
To make ColPali models work even better, we might want a dataset of query/image document pairs related to our own domain or task.

One way we might go about generating such a dataset is to use a VLM to generate synthetic queries for us.
This Space uses the [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) VLM to generate queries for a document, based on an input document image.

**Note:** there is a lot of scope for improving the prompts and the quality of the generated queries! If you have any suggestions for improvements, please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!

This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models.

If you want to convert PDFs to a dataset of page images, you can try out the [PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
"""

examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]

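# Minimal Gradio UI: upload a document page image, get the generated queries back as pretty-printed JSON.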
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Image(type="pil"),
    outputs=gr.Text(),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()