import os
import faiss
import pickle
import numpy as np
from typing import List, Dict
from docarray import Document, DocumentArray
from jina import Executor, requests
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BlipProcessor, BlipForConditionalGeneration, BitsAndBytesConfig
import fitz
from PIL import Image
import traceback
import torch
import re
import io
class MultimodalRAGExecutor(Executor):
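    """Jina Executor for multimodal RAG over PDF files.

    Text paragraphs and BLIP-generated image captions are embedded with
    sentence-transformers, stored in a FAISS flat L2 index, and retrieved at
    query time as context for a Qwen 2.5 instruct model (4-bit on GPU).
    """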
def __init__(
self,
llm_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
vision_model: str = "Salesforce/blip-image-captioning-base",
index_file: str = "faiss_index.bin",
metadata_file: str = "metadata.pkl",
dim: int = 384,
**kwargs,
):
super().__init__(**kwargs)
self.llm_model_name = llm_model_name
self.embedding_model = embedding_model
self.vision_model_name = vision_model
self.index_file = index_file
self.metadata_file = metadata_file
self.dim = dim
self.hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
        # Report only whether a token is present; never print token contents to logs.
        if self.hf_token:
            print("Hugging Face token found")
        else:
            print("No Hugging Face token set")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {self.device}")
# Load embedding model
print(f"Loading embeddings: {embedding_model}")
self.embedder = SentenceTransformer(embedding_model)
print("Embeddings loaded")
# Load BLIP vision
print(f"Loading vision: {vision_model}")
try:
self.vision_processor = BlipProcessor.from_pretrained(vision_model)
self.vision_model = BlipForConditionalGeneration.from_pretrained(
vision_model,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
).to(self.device)
self.vision_model.eval()
print("Vision loaded")
except Exception as e:
print(f"Vision error: {e}")
self.vision_processor = None
self.vision_model = None
# Load Qwen text model
print(f"Loading text: {llm_model_name}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
if self.device == "cuda":
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4"
)
self.llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_name,
quantization_config=quantization_config,
device_map="auto",
torch_dtype=torch.float16
)
else:
self.llm_model = AutoModelForCausalLM.from_pretrained(
llm_model_name,
device_map="auto",
torch_dtype=torch.float32,
low_cpu_mem_usage=True
)
self.llm_model.eval()
print("Text model loaded")
except Exception as e:
print(f"Text error: {e}")
self.llm_model = None
self.tokenizer = None
self._load_or_create_index()
def _load_or_create_index(self):
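        """Load the persisted FAISS index and metadata if both files exist; otherwise start a fresh flat L2 index."""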
if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
try:
self.index = faiss.read_index(self.index_file)
with open(self.metadata_file, "rb") as f:
self.metadata = pickle.load(f)
print(f"Index loaded: {self.index.ntotal} vectors")
except Exception as e:
print(f"Index error: {e}")
self.index = faiss.IndexFlatL2(self.dim)
self.metadata = []
else:
self.index = faiss.IndexFlatL2(self.dim)
self.metadata = []
print("New index created")
def _get_embedding(self, text: str) -> np.ndarray:
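        """Encode text into a float32 vector; its dimensionality must match self.dim (384 for all-MiniLM-L6-v2)."""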
embedding = self.embedder.encode(text, convert_to_numpy=True)
return embedding.astype(np.float32)
def _analyze_image(self, image: Image.Image) -> str:
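        """Describe an image with BLIP: an unconditional caption plus a prompted 'detailed description' pass."""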
if not self.vision_processor or not self.vision_model:
return "Image analysis unavailable"
try:
            inputs = self.vision_processor(image, return_tensors="pt").to(self.device)
            # Match pixel values to the model dtype (fp16 on GPU) to avoid a dtype mismatch.
            inputs["pixel_values"] = inputs["pixel_values"].to(self.vision_model.dtype)
with torch.no_grad():
out = self.vision_model.generate(**inputs, max_length=100)
caption = self.vision_processor.decode(out[0], skip_special_tokens=True)
text = "a detailed description of"
            inputs = self.vision_processor(image, text, return_tensors="pt").to(self.device)
            inputs["pixel_values"] = inputs["pixel_values"].to(self.vision_model.dtype)
with torch.no_grad():
out = self.vision_model.generate(**inputs, max_length=120)
detailed = self.vision_processor.decode(out[0], skip_special_tokens=True)
return f"Caption: {caption}. Details: {detailed}"
except Exception as e:
print(f"Image error: {e}")
return "Image analysis failed"
def _extract_text_from_pdf(self, pdf_path: str) -> List[str]:
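        """Extract per-page text with PyMuPDF; pages without extractable text are skipped."""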
texts = []
try:
doc = fitz.open(pdf_path)
for page_num, page in enumerate(doc, start=1):
text = page.get_text("text")
if text and text.strip():
texts.append(f"Page {page_num}:\n{text.strip()}")
doc.close()
except Exception as e:
print(f"Text extraction error: {e}")
return texts
def _extract_images_from_pdf(self, pdf_path: str) -> List[Dict]:
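        """Extract embedded images by xref as RGB PIL images, keeping only those at least 100x100 px."""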
images_data = []
try:
doc = fitz.open(pdf_path)
for page_num, page in enumerate(doc, start=1):
image_list = page.get_images(full=True)
for img_idx, img in enumerate(image_list):
try:
xref = img[0]
base_image = doc.extract_image(xref)
image_bytes = base_image["image"]
pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
width, height = pil_image.size
if width >= 100 and height >= 100:
images_data.append({
'image': pil_image,
'page': page_num,
'index': img_idx
})
except Exception as e:
print(f"Image extract error page {page_num}: {e}")
continue
doc.close()
except Exception as e:
print(f"PDF image error: {e}")
return images_data
def _generate_answer(self, prompt: str, context: str) -> str:
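        """Generate an answer with the LLM, falling back to extractive answering if the model is missing or generation fails."""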
if not self.llm_model or not self.tokenizer:
return self._extractive_answer(prompt, context)
try:
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.llm_model.device)
with torch.no_grad():
outputs = self.llm_model.generate(
**inputs,
max_new_tokens=256,
temperature=0.3,
do_sample=True,
top_p=0.9,
repetition_penalty=1.1
)
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, outputs)
]
answer = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
return answer.strip()
except Exception as e:
print(f"Generation error: {e}")
return self._extractive_answer(prompt, context)
def _extractive_answer(self, query: str, context: str) -> str:
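        """Fallback answer: rank sentences by keyword overlap with the query and return the top three."""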
sentences = re.split(r'[.!?]+', context)
sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
query_words = set(query.lower().split())
scored = []
for sent in sentences[:30]:
sent_words = set(sent.lower().split())
overlap = len(query_words.intersection(sent_words))
for word in query_words:
if len(word) > 4 and word in sent.lower():
overlap += 2
if overlap > 0:
scored.append((overlap, sent))
scored.sort(reverse=True, key=lambda x: x[0])
if scored:
top_sentences = [s for _, s in scored[:3]]
return ". ".join(top_sentences) + "."
return "Could not find relevant information"
@requests(on="/upload")
def upload(self, docs: DocumentArray, **kwargs):
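        """Index PDFs referenced by doc.uri: embed text paragraphs and captioned images, then persist the index to disk."""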
        try:
            # Accumulate counts across all uploaded documents.
            text_count = 0
            image_count = 0
            for doc in docs:
file_path = doc.uri
if file_path.startswith("file://"):
file_path = file_path.replace("file://", "")
if not os.path.exists(file_path) or not file_path.endswith(".pdf"):
continue
print(f"Processing: {file_path}")
                text_chunks = self._extract_text_from_pdf(file_path)
for chunk in text_chunks:
paragraphs = chunk.split("\n\n")
for para in paragraphs:
if para.strip() and len(para.strip()) > 50:
emb = self._get_embedding(para.strip())
self.index.add(np.array([emb]))
self.metadata.append({
"type": "text",
"content": para.strip()
})
text_count += 1
print(f"Indexed {text_count} text chunks")
                images_data = self._extract_images_from_pdf(file_path)
for img_data in images_data:
description = self._analyze_image(img_data['image'])
img_path = os.path.abspath(
f"img_p{img_data['page']}_i{img_data['index']}.png"
)
img_data['image'].save(img_path, "PNG")
embed_text = f"Image from page {img_data['page']}: {description}"
emb = self._get_embedding(embed_text)
self.index.add(np.array([emb]))
self.metadata.append({
"type": "image",
"content": f"file://{img_path}",
"description": description,
"page": img_data['page']
})
image_count += 1
print(f"Analyzed {image_count} images")
# Save index
faiss.write_index(self.index, self.index_file)
with open(self.metadata_file, "wb") as f:
pickle.dump(self.metadata, f)
summary = f"Upload complete!\n"
summary += f"Total vectors: {self.index.ntotal}\n"
summary += f"Text chunks: {text_count}\n"
summary += f"Images: {image_count}\n"
summary += f"Using Qwen 2.5 & BLIP"
return DocumentArray([Document(text=summary)])
except Exception as e:
error_msg = f"Upload failed:\n{traceback.format_exc()}"
print(error_msg)
return DocumentArray([Document(text=error_msg)])
@requests(on="/query")
def query(self, docs: DocumentArray, **kwargs):
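        """Answer each query from the top-10 nearest chunks, attaching any matched images as chunks of the answer."""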
results = DocumentArray()
if self.index.ntotal == 0:
return DocumentArray([
Document(text="No documents uploaded. Please upload PDF first.")
])
for doc in docs:
try:
query_text = doc.text
query_emb = self._get_embedding(query_text)
D, I = self.index.search(np.array([query_emb]), k=10)
context_parts = []
matched_images = []
image_descriptions = []
for idx in I[0]:
                    # FAISS pads results with -1 when fewer than k vectors exist.
                    if 0 <= idx < len(self.metadata):
meta = self.metadata[idx]
if meta["type"] == "text":
context_parts.append(meta["content"])
elif meta["type"] == "image":
matched_images.append(Document(uri=meta["content"]))
image_descriptions.append(
f"[Image Page {meta.get('page', '?')}]: {meta['description']}"
)
context_text = "\n\n".join(context_parts[:5])
if image_descriptions:
context_text += "\n\nRelevant Images:\n" + "\n".join(image_descriptions[:3])
if len(context_text) > 2500:
context_text = context_text[:2500] + "..."
# Qwen chat format
messages = [
{
"role": "system",
"content": "You are a helpful research assistant. Answer accurately based on context."
},
{
"role": "user",
"content": f"""Context from research paper:
{context_text}
Question: {query_text}
Provide a clear and accurate answer based only on the context."""
}
]
                # Fall back to the raw question if the tokenizer failed to load.
                if self.tokenizer:
                    prompt = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True
                    )
                else:
                    prompt = query_text
                answer = self._generate_answer(prompt, context_text)
# Clean answer
answer = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', answer, flags=re.DOTALL)
answer = re.sub(r'^(Question|Answer|Context):\s*', '', answer, flags=re.IGNORECASE)
answer = answer.strip()
answer_doc = Document(text=answer)
if matched_images:
answer_doc.chunks = DocumentArray(matched_images[:4])
results.append(answer_doc)
except Exception as e:
error_msg = f"Query failed: {str(e)}\n{traceback.format_exc()}"
print(error_msg)
results.append(Document(text=error_msg))
return results
@requests(on="/stats")
def stats(self, docs: DocumentArray, **kwargs):
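        """Report the number of indexed vectors, text chunks, and images."""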
text_count = sum(1 for m in self.metadata if m["type"] == "text")
image_count = sum(1 for m in self.metadata if m["type"] == "image")
stats_text = (
f"Index Statistics:\n"
f"Total vectors: {self.index.ntotal}\n"
f"Text chunks: {text_count}\n"
f"Images: {image_count}\n"
f"Using Qwen 2.5 & BLIP"
)
return DocumentArray([Document(text=stats_text)])
@requests(on="/reset")
def reset(self, docs: DocumentArray, **kwargs):
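        """Clear the in-memory index and metadata and delete the persisted index files."""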
try:
self.index = faiss.IndexFlatL2(self.dim)
self.metadata = []
if os.path.exists(self.index_file):
os.remove(self.index_file)
if os.path.exists(self.metadata_file):
os.remove(self.metadata_file)
return DocumentArray([Document(text="Index reset successfully")])
except Exception as e:
            return DocumentArray([Document(text=f"Reset failed: {str(e)}")])
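

if __name__ == "__main__":
    # Minimal local usage sketch, not part of the original deployment. It assumes
    # Jina 3.x with the docarray<0.30 Document/DocumentArray API used above; the
    # port and the PDF path are hypothetical placeholders.
    from jina import Flow

    flow = Flow(port=12345).add(uses=MultimodalRAGExecutor)
    with flow:
        # Index a PDF, then query the indexed content.
        upload_resp = flow.post(
            on="/upload",
            inputs=DocumentArray([Document(uri="file:///path/to/paper.pdf")]),
        )
        print(upload_resp[0].text)
        query_resp = flow.post(
            on="/query",
            inputs=DocumentArray([Document(text="What problem does the paper address?")]),
        )
        print(query_resp[0].text)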