import os
import io
import re
import pickle
import traceback
from typing import List, Dict

import faiss
import fitz
import numpy as np
import torch
from PIL import Image
from docarray import Document, DocumentArray
from jina import Executor, requests
from pdfminer.high_level import extract_text
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    BlipForConditionalGeneration,
    BlipProcessor,
)


class MultimodalRAGExecutor(Executor):
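    """Jina Executor for multimodal RAG over PDF documents.

    PDF text chunks and BLIP image captions are embedded with a
    SentenceTransformer and stored in a FAISS index; queries are answered by a
    quantized Qwen 2.5 LLM, with an extractive fallback when the LLM is
    unavailable.
    """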

    def __init__(
        self,
        llm_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        vision_model: str = "Salesforce/blip-image-captioning-base",
        index_file: str = "faiss_index.bin",
        metadata_file: str = "metadata.pkl",
        dim: int = 384,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.llm_model_name = llm_model_name
        self.embedding_model = embedding_model
        self.vision_model_name = vision_model
        self.index_file = index_file
        self.metadata_file = metadata_file
        self.dim = dim

        # Optional Hugging Face token (e.g. for gated models); only its presence is logged.
        self.hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
        if self.hf_token:
            print("Hugging Face token found")
        else:
            print("No Hugging Face token set")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {self.device}")

        # Sentence-embedding model used for both indexing and querying.
        print(f"Loading embeddings: {embedding_model}")
        self.embedder = SentenceTransformer(embedding_model)
        print("Embeddings loaded")

        # BLIP captioning model; image indexing is skipped gracefully if it fails to load.
        print(f"Loading vision: {vision_model}")
        try:
            self.vision_processor = BlipProcessor.from_pretrained(vision_model)
            self.vision_model = BlipForConditionalGeneration.from_pretrained(
                vision_model,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
            self.vision_model.eval()
            print("Vision loaded")
        except Exception as e:
            print(f"Vision error: {e}")
            self.vision_processor = None
            self.vision_model = None

        # Causal LLM for answer generation: 4-bit quantized on GPU, full precision on CPU.
        print(f"Loading text: {llm_model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

            if self.device == "cuda":
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                self.llm_model = AutoModelForCausalLM.from_pretrained(
                    llm_model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                )
            else:
                self.llm_model = AutoModelForCausalLM.from_pretrained(
                    llm_model_name,
                    device_map="auto",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                )

            self.llm_model.eval()
            print("Text model loaded")
        except Exception as e:
            print(f"Text error: {e}")
            self.llm_model = None
            self.tokenizer = None

        self._load_or_create_index()

    def _load_or_create_index(self):
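        """Load the FAISS index and metadata from disk, or start a fresh empty index."""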
        if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
            try:
                self.index = faiss.read_index(self.index_file)
                with open(self.metadata_file, "rb") as f:
                    self.metadata = pickle.load(f)
                print(f"Index loaded: {self.index.ntotal} vectors")
            except Exception as e:
                print(f"Index error: {e}")
                self.index = faiss.IndexFlatL2(self.dim)
                self.metadata = []
        else:
            self.index = faiss.IndexFlatL2(self.dim)
            self.metadata = []
            print("New index created")

    def _get_embedding(self, text: str) -> np.ndarray:
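        """Embed a text string and return a float32 vector compatible with the FAISS index."""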
        embedding = self.embedder.encode(text, convert_to_numpy=True)
        return embedding.astype(np.float32)

    def _analyze_image(self, image: Image.Image) -> str:
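        """Caption an image with BLIP, combining an unconditional and a prompted caption."""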
        if not self.vision_processor or not self.vision_model:
            return "Image analysis unavailable"

        try:
            # Unconditional caption.
            inputs = self.vision_processor(image, return_tensors="pt").to(self.device)
            if self.device == "cuda":
                # The BLIP weights are loaded in float16 on GPU, so cast the image tensor to match.
                inputs["pixel_values"] = inputs["pixel_values"].half()
            with torch.no_grad():
                out = self.vision_model.generate(**inputs, max_length=100)
            caption = self.vision_processor.decode(out[0], skip_special_tokens=True)

            # Prompted caption for a more detailed description.
            text = "a detailed description of"
            inputs = self.vision_processor(image, text, return_tensors="pt").to(self.device)
            if self.device == "cuda":
                inputs["pixel_values"] = inputs["pixel_values"].half()
            with torch.no_grad():
                out = self.vision_model.generate(**inputs, max_length=120)
            detailed = self.vision_processor.decode(out[0], skip_special_tokens=True)

            return f"Caption: {caption}. Details: {detailed}"

        except Exception as e:
            print(f"Image error: {e}")
            return "Image analysis failed"

    def _extract_text_from_pdf(self, pdf_path: str) -> List[str]:
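        """Extract page-level text from a PDF with PyMuPDF, one string per non-empty page."""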
        texts = []
        try:
            doc = fitz.open(pdf_path)
            for page_num, page in enumerate(doc, start=1):
                text = page.get_text("text")
                if text and text.strip():
                    texts.append(f"Page {page_num}:\n{text.strip()}")
            doc.close()
        except Exception as e:
            print(f"Text extraction error: {e}")
        return texts

    def _extract_images_from_pdf(self, pdf_path: str) -> List[Dict]:
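        """Extract embedded images from a PDF, skipping very small ones (icons, decorations)."""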
        images_data = []
        try:
            doc = fitz.open(pdf_path)
            for page_num, page in enumerate(doc, start=1):
                image_list = page.get_images(full=True)

                for img_idx, img in enumerate(image_list):
                    try:
                        xref = img[0]
                        base_image = doc.extract_image(xref)
                        image_bytes = base_image["image"]

                        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

                        # Keep only reasonably sized images; tiny ones are usually icons.
                        width, height = pil_image.size
                        if width >= 100 and height >= 100:
                            images_data.append({
                                "image": pil_image,
                                "page": page_num,
                                "index": img_idx,
                            })
                    except Exception as e:
                        print(f"Image extract error page {page_num}: {e}")
                        continue

            doc.close()
        except Exception as e:
            print(f"PDF image error: {e}")

        return images_data

    def _generate_answer(self, prompt: str, context: str) -> str:
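        """Generate an answer with the LLM, falling back to extractive answering on failure."""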
        if not self.llm_model or not self.tokenizer:
            return self._extractive_answer(prompt, context)

        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.llm_model.device)

            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                )

            # Strip the prompt tokens so only the newly generated answer is decoded.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(inputs.input_ids, outputs)
            ]

            answer = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

            return answer.strip()

        except Exception as e:
            print(f"Generation error: {e}")
            return self._extractive_answer(prompt, context)

    def _extractive_answer(self, query: str, context: str) -> str:
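        """Keyword-overlap fallback: return the context sentences that best match the query."""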
        sentences = re.split(r"[.!?]+", context)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 20]

        query_words = set(query.lower().split())
        scored = []

        # Score each sentence by word overlap, boosting longer query terms that appear verbatim.
        for sent in sentences[:30]:
            sent_words = set(sent.lower().split())
            overlap = len(query_words.intersection(sent_words))

            for word in query_words:
                if len(word) > 4 and word in sent.lower():
                    overlap += 2

            if overlap > 0:
                scored.append((overlap, sent))

        scored.sort(reverse=True, key=lambda x: x[0])

        if scored:
            top_sentences = [s for _, s in scored[:3]]
            return ". ".join(top_sentences) + "."

        return "Could not find relevant information"

    @requests(on="/upload")
    def upload(self, docs: DocumentArray, **kwargs):
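        """Index the text chunks and captioned images of each uploaded PDF."""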
        try:
            # Counters are defined up front so the summary is valid even if no PDF is
            # processed; they reflect the most recently processed document.
            text_count = 0
            image_count = 0

            for doc in docs:
                file_path = doc.uri or ""
                if file_path.startswith("file://"):
                    file_path = file_path.replace("file://", "")

                if not os.path.exists(file_path) or not file_path.endswith(".pdf"):
                    continue

                print(f"Processing: {file_path}")

                # Index text: split each page into paragraphs and embed the substantial ones.
                text_chunks = self._extract_text_from_pdf(file_path)

                text_count = 0
                for chunk in text_chunks:
                    paragraphs = chunk.split("\n\n")
                    for para in paragraphs:
                        if para.strip() and len(para.strip()) > 50:
                            emb = self._get_embedding(para.strip())
                            self.index.add(np.array([emb]))
                            self.metadata.append({
                                "type": "text",
                                "content": para.strip(),
                            })
                            text_count += 1

                print(f"Indexed {text_count} text chunks")

                # Index images: caption each extracted image and embed its description.
                images_data = self._extract_images_from_pdf(file_path)
                image_count = 0

                for img_data in images_data:
                    description = self._analyze_image(img_data["image"])

                    img_path = os.path.abspath(
                        f"img_p{img_data['page']}_i{img_data['index']}.png"
                    )
                    img_data["image"].save(img_path, "PNG")

                    embed_text = f"Image from page {img_data['page']}: {description}"
                    emb = self._get_embedding(embed_text)
                    self.index.add(np.array([emb]))
                    self.metadata.append({
                        "type": "image",
                        "content": f"file://{img_path}",
                        "description": description,
                        "page": img_data["page"],
                    })
                    image_count += 1

                print(f"Analyzed {image_count} images")

            # Persist the index and metadata so they survive restarts.
            faiss.write_index(self.index, self.index_file)
            with open(self.metadata_file, "wb") as f:
                pickle.dump(self.metadata, f)

            summary = "Upload complete!\n"
            summary += f"Total vectors: {self.index.ntotal}\n"
            summary += f"Text chunks: {text_count}\n"
            summary += f"Images: {image_count}\n"
            summary += "Using Qwen 2.5 & BLIP"

            return DocumentArray([Document(text=summary)])

        except Exception:
            error_msg = f"Upload failed:\n{traceback.format_exc()}"
            print(error_msg)
            return DocumentArray([Document(text=error_msg)])

    @requests(on="/query")
    def query(self, docs: DocumentArray, **kwargs):
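        """Retrieve the most relevant text chunks and images for each query and generate an answer."""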
        results = DocumentArray()

        if self.index.ntotal == 0:
            return DocumentArray([
                Document(text="No documents uploaded. Please upload a PDF first.")
            ])

        for doc in docs:
            try:
                query_text = doc.text

                # Retrieve the nearest neighbours of the query embedding.
                query_emb = self._get_embedding(query_text)
                D, I = self.index.search(np.array([query_emb]), k=10)

                context_parts = []
                matched_images = []
                image_descriptions = []

                for idx in I[0]:
                    # FAISS pads results with -1 when fewer than k vectors are indexed.
                    if 0 <= idx < len(self.metadata):
                        meta = self.metadata[idx]
                        if meta["type"] == "text":
                            context_parts.append(meta["content"])
                        elif meta["type"] == "image":
                            matched_images.append(Document(uri=meta["content"]))
                            image_descriptions.append(
                                f"[Image Page {meta.get('page', '?')}]: {meta['description']}"
                            )

                context_text = "\n\n".join(context_parts[:5])

                if image_descriptions:
                    context_text += "\n\nRelevant Images:\n" + "\n".join(image_descriptions[:3])

                if len(context_text) > 2500:
                    context_text = context_text[:2500] + "..."

                messages = [
                    {
                        "role": "system",
                        "content": "You are a helpful research assistant. Answer accurately based on context.",
                    },
                    {
                        "role": "user",
                        "content": f"""Context from research paper:
{context_text}

Question: {query_text}

Provide a clear and accurate answer based only on the context.""",
                    },
                ]

                # Build the chat prompt only if the tokenizer loaded; otherwise pass the raw
                # question so _generate_answer can fall back to extractive answering.
                if self.tokenizer is not None:
                    prompt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                else:
                    prompt = query_text

                answer = self._generate_answer(prompt, context_text)

                # Remove any leaked chat-template markers or boilerplate prefixes.
                answer = re.sub(r"<\|im_start\|>.*?<\|im_end\|>", "", answer, flags=re.DOTALL)
                answer = re.sub(r"^(Question|Answer|Context):\s*", "", answer, flags=re.IGNORECASE)
                answer = answer.strip()

                answer_doc = Document(text=answer)
                if matched_images:
                    answer_doc.chunks = DocumentArray(matched_images[:4])

                results.append(answer_doc)

            except Exception as e:
                error_msg = f"Query failed: {str(e)}\n{traceback.format_exc()}"
                print(error_msg)
                results.append(Document(text=error_msg))

        return results

    @requests(on="/stats")
    def stats(self, docs: DocumentArray, **kwargs):
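        """Report the number of indexed vectors, text chunks, and images."""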
        text_count = sum(1 for m in self.metadata if m["type"] == "text")
        image_count = sum(1 for m in self.metadata if m["type"] == "image")

        stats_text = (
            f"Index Statistics:\n"
            f"Total vectors: {self.index.ntotal}\n"
            f"Text chunks: {text_count}\n"
            f"Images: {image_count}\n"
            f"Using Qwen 2.5 & BLIP"
        )

        return DocumentArray([Document(text=stats_text)])

    @requests(on="/reset")
    def reset(self, docs: DocumentArray, **kwargs):
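        """Clear the in-memory index and metadata and delete the persisted files."""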
        try:
            self.index = faiss.IndexFlatL2(self.dim)
            self.metadata = []

            if os.path.exists(self.index_file):
                os.remove(self.index_file)
            if os.path.exists(self.metadata_file):
                os.remove(self.metadata_file)

            return DocumentArray([Document(text="Index reset successfully")])
        except Exception as e:
            return DocumentArray([Document(text=f"Reset failed: {str(e)}")])