"""Jina Executor for multimodal RAG over PDFs.

Text chunks and BLIP-generated image captions are embedded with
sentence-transformers, stored in a FAISS index, and used as context for a
Qwen 2.5 instruct model at query time.
"""

import io
import os
import pickle
import re
import traceback
from typing import Dict, List

import faiss
import fitz  # PyMuPDF
import numpy as np
import torch
from docarray import Document, DocumentArray
from jina import Executor, requests
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    BlipForConditionalGeneration,
    BlipProcessor,
)


class MultimodalRAGExecutor(Executor):
    def __init__(
        self,
        llm_model_name: str = "Qwen/Qwen2.5-3B-Instruct",
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
        vision_model: str = "Salesforce/blip-image-captioning-base",
        index_file: str = "faiss_index.bin",
        metadata_file: str = "metadata.pkl",
        dim: int = 384,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.llm_model_name = llm_model_name
        self.embedding_model = embedding_model
        self.vision_model_name = vision_model
        self.index_file = index_file
        self.metadata_file = metadata_file
        self.dim = dim

        self.hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
        if self.hf_token:
            print(f"Token: {self.hf_token[:10]}...")
        else:
            print("No token")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {self.device}")

        # Load embedding model
        print(f"Loading embeddings: {embedding_model}")
        self.embedder = SentenceTransformer(embedding_model)
        print("Embeddings loaded")

        # Load BLIP vision model
        print(f"Loading vision: {vision_model}")
        try:
            self.vision_processor = BlipProcessor.from_pretrained(vision_model)
            self.vision_model = BlipForConditionalGeneration.from_pretrained(
                vision_model,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
            self.vision_model.eval()
            print("Vision loaded")
        except Exception as e:
            print(f"Vision error: {e}")
            self.vision_processor = None
            self.vision_model = None

        # Load Qwen text model (4-bit quantized on GPU, full precision on CPU)
        print(f"Loading text: {llm_model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
            if self.device == "cuda":
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                self.llm_model = AutoModelForCausalLM.from_pretrained(
                    llm_model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,
                )
            else:
                self.llm_model = AutoModelForCausalLM.from_pretrained(
                    llm_model_name,
                    device_map="auto",
                    torch_dtype=torch.float32,
                    low_cpu_mem_usage=True,
                )
            self.llm_model.eval()
            print("Text model loaded")
        except Exception as e:
            print(f"Text error: {e}")
            self.llm_model = None
            self.tokenizer = None

        self._load_or_create_index()

    def _load_or_create_index(self):
        if os.path.exists(self.index_file) and os.path.exists(self.metadata_file):
            try:
                self.index = faiss.read_index(self.index_file)
                with open(self.metadata_file, "rb") as f:
                    self.metadata = pickle.load(f)
                print(f"Index loaded: {self.index.ntotal} vectors")
            except Exception as e:
                print(f"Index error: {e}")
                self.index = faiss.IndexFlatL2(self.dim)
                self.metadata = []
        else:
            self.index = faiss.IndexFlatL2(self.dim)
            self.metadata = []
            print("New index created")

    def _get_embedding(self, text: str) -> np.ndarray:
        embedding = self.embedder.encode(text, convert_to_numpy=True)
        return embedding.astype(np.float32)

    def _analyze_image(self, image: Image.Image) -> str:
        """Caption an extracted image with BLIP (unconditional pass plus a prompted pass)."""
        if not self.vision_processor or not self.vision_model:
            return "Image analysis unavailable"
        try:
            # Cast pixel values to the model dtype (float16 on GPU) to avoid a
            # half/float mismatch inside BLIP's vision encoder.
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            inputs = self.vision_processor(image, return_tensors="pt").to(self.device, dtype)
            with torch.no_grad():
                out = self.vision_model.generate(**inputs, max_length=100)
            caption = self.vision_processor.decode(out[0], skip_special_tokens=True)

            # Second pass: conditional captioning with a text prompt for more detail.
            text = "a detailed description of"
            inputs = self.vision_processor(image, text, return_tensors="pt").to(self.device, dtype)
            with torch.no_grad():
                out = self.vision_model.generate(**inputs, max_length=120)
            detailed = self.vision_processor.decode(out[0], skip_special_tokens=True)

            return f"Caption: {caption}. Details: {detailed}"
        except Exception as e:
            print(f"Image error: {e}")
            return "Image analysis failed"

    def _extract_text_from_pdf(self, pdf_path: str) -> List[str]:
        """Return one text block per PDF page via PyMuPDF."""
        texts = []
        try:
            doc = fitz.open(pdf_path)
            for page_num, page in enumerate(doc, start=1):
                text = page.get_text("text")
                if text and text.strip():
                    texts.append(f"Page {page_num}:\n{text.strip()}")
            doc.close()
        except Exception as e:
            print(f"Text extraction error: {e}")
        return texts

    def _extract_images_from_pdf(self, pdf_path: str) -> List[Dict]:
        """Extract embedded images of at least 100x100 px from every page."""
        images_data = []
        try:
            doc = fitz.open(pdf_path)
            for page_num, page in enumerate(doc, start=1):
                image_list = page.get_images(full=True)
                for img_idx, img in enumerate(image_list):
                    try:
                        xref = img[0]
                        base_image = doc.extract_image(xref)
                        image_bytes = base_image["image"]
                        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
                        width, height = pil_image.size
                        # Skip icons, logos, and other tiny decorations.
                        if width >= 100 and height >= 100:
                            images_data.append({
                                "image": pil_image,
                                "page": page_num,
                                "index": img_idx,
                            })
                    except Exception as e:
                        print(f"Image extract error page {page_num}: {e}")
                        continue
            doc.close()
        except Exception as e:
            print(f"PDF image error: {e}")
        return images_data

    def _generate_answer(self, prompt: str, context: str) -> str:
        """Generate an answer with Qwen; fall back to extractive QA if the LLM is unavailable."""
        if not self.llm_model or not self.tokenizer:
            return self._extractive_answer(prompt, context)
        try:
            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.llm_model.device)
            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    temperature=0.3,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=1.1,
                )
            # Strip the prompt tokens so only the newly generated answer remains.
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(inputs.input_ids, outputs)
            ]
            answer = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return answer.strip()
        except Exception as e:
            print(f"Generation error: {e}")
            return self._extractive_answer(prompt, context)

    def _extractive_answer(self, query: str, context: str) -> str:
        """Keyword-overlap fallback: return the three best-matching context sentences."""
        sentences = re.split(r"[.!?]+", context)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 20]
        query_words = set(query.lower().split())
        scored = []
        for sent in sentences[:30]:
            sent_words = set(sent.lower().split())
            overlap = len(query_words.intersection(sent_words))
            # Give extra weight to longer, more distinctive query terms.
            for word in query_words:
                if len(word) > 4 and word in sent.lower():
                    overlap += 2
            if overlap > 0:
                scored.append((overlap, sent))
        scored.sort(reverse=True, key=lambda x: x[0])
        if scored:
            top_sentences = [s for _, s in scored[:3]]
            return ". ".join(top_sentences) + "."
        return "Could not find relevant information"

    @requests(on="/upload")
    def upload(self, docs: DocumentArray, **kwargs):
        try:
            # Counters accumulate across all uploaded PDFs so the summary below
            # is valid even when a request carries no usable documents.
            text_count = 0
            image_count = 0
            for doc in docs:
                file_path = doc.uri
                if file_path.startswith("file://"):
                    file_path = file_path.replace("file://", "")
                if not os.path.exists(file_path) or not file_path.endswith(".pdf"):
                    continue
                print(f"Processing: {file_path}")

                # Index text: one vector per paragraph longer than 50 characters.
                text_chunks = self._extract_text_from_pdf(file_path)
                for chunk in text_chunks:
                    paragraphs = chunk.split("\n\n")
                    for para in paragraphs:
                        if para.strip() and len(para.strip()) > 50:
                            emb = self._get_embedding(para.strip())
                            self.index.add(np.array([emb]))
                            self.metadata.append({
                                "type": "text",
                                "content": para.strip(),
                            })
                            text_count += 1
                print(f"Indexed {text_count} text chunks")

                # Index images: caption each one with BLIP and embed the caption.
                images_data = self._extract_images_from_pdf(file_path)
                for img_data in images_data:
                    description = self._analyze_image(img_data["image"])
                    img_path = os.path.abspath(
                        f"img_p{img_data['page']}_i{img_data['index']}.png"
                    )
                    img_data["image"].save(img_path, "PNG")
                    embed_text = f"Image from page {img_data['page']}: {description}"
                    emb = self._get_embedding(embed_text)
                    self.index.add(np.array([emb]))
                    self.metadata.append({
                        "type": "image",
                        "content": f"file://{img_path}",
                        "description": description,
                        "page": img_data["page"],
                    })
                    image_count += 1
                print(f"Analyzed {image_count} images")

            # Save index and metadata
            faiss.write_index(self.index, self.index_file)
            with open(self.metadata_file, "wb") as f:
                pickle.dump(self.metadata, f)

            summary = f"Upload complete!\n"
            summary += f"Total vectors: {self.index.ntotal}\n"
            summary += f"Text chunks: {text_count}\n"
            summary += f"Images: {image_count}\n"
            summary += f"Using Qwen 2.5 & BLIP"
            return DocumentArray([Document(text=summary)])
        except Exception as e:
            error_msg = f"Upload failed:\n{traceback.format_exc()}"
            print(error_msg)
            return DocumentArray([Document(text=error_msg)])

    @requests(on="/query")
    def query(self, docs: DocumentArray, **kwargs):
        results = DocumentArray()
        if self.index.ntotal == 0:
            return DocumentArray([
                Document(text="No documents uploaded. Please upload a PDF first.")
            ])
        for doc in docs:
            try:
                query_text = doc.text
                query_emb = self._get_embedding(query_text)
                # Retrieve the 10 nearest text/image vectors.
                D, I = self.index.search(np.array([query_emb]), k=10)

                context_parts = []
                matched_images = []
                image_descriptions = []
                for idx in I[0]:
                    # FAISS pads with -1 when fewer than k vectors exist.
                    if 0 <= idx < len(self.metadata):
                        meta = self.metadata[idx]
                        if meta["type"] == "text":
                            context_parts.append(meta["content"])
                        elif meta["type"] == "image":
                            matched_images.append(Document(uri=meta["content"]))
                            image_descriptions.append(
                                f"[Image Page {meta.get('page', '?')}]: {meta['description']}"
                            )

                context_text = "\n\n".join(context_parts[:5])
                if image_descriptions:
                    context_text += "\n\nRelevant Images:\n" + "\n".join(image_descriptions[:3])
                if len(context_text) > 2500:
                    context_text = context_text[:2500] + "..."

                # Qwen chat format
                messages = [
                    {
                        "role": "system",
                        "content": "You are a helpful research assistant. Answer accurately based on context.",
                    },
                    {
                        "role": "user",
                        "content": f"""Context from research paper:
{context_text}

Question: {query_text}

Provide a clear and accurate answer based only on the context.""",
                    },
                ]

                # Build the prompt with Qwen's chat template when the tokenizer is
                # available; otherwise pass the raw question so the extractive
                # fallback in _generate_answer still receives a usable query.
                if self.tokenizer is not None:
                    prompt = self.tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                else:
                    prompt = query_text
                answer = self._generate_answer(prompt, context_text)

                # Clean answer: drop any leaked chat-template markers and labels.
                answer = re.sub(r"<\|im_start\|>.*?<\|im_end\|>", "", answer, flags=re.DOTALL)
                answer = re.sub(r"^(Question|Answer|Context):\s*", "", answer, flags=re.IGNORECASE)
                answer = answer.strip()

                answer_doc = Document(text=answer)
                if matched_images:
                    answer_doc.chunks = DocumentArray(matched_images[:4])
                results.append(answer_doc)
            except Exception as e:
                error_msg = f"Query failed: {str(e)}\n{traceback.format_exc()}"
                print(error_msg)
                results.append(Document(text=error_msg))
        return results

    @requests(on="/stats")
    def stats(self, docs: DocumentArray, **kwargs):
        text_count = sum(1 for m in self.metadata if m["type"] == "text")
        image_count = sum(1 for m in self.metadata if m["type"] == "image")
        stats_text = (
            f"Index Statistics:\n"
            f"Total vectors: {self.index.ntotal}\n"
            f"Text chunks: {text_count}\n"
            f"Images: {image_count}\n"
            f"Using Qwen 2.5 & BLIP"
        )
        return DocumentArray([Document(text=stats_text)])

    @requests(on="/reset")
    def reset(self, docs: DocumentArray, **kwargs):
        try:
            self.index = faiss.IndexFlatL2(self.dim)
            self.metadata = []
            if os.path.exists(self.index_file):
                os.remove(self.index_file)
            if os.path.exists(self.metadata_file):
                os.remove(self.metadata_file)
            return DocumentArray([Document(text="Index reset successfully")])
        except Exception as e:
            return DocumentArray([Document(text=f"Reset failed: {str(e)}")])
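

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the executor): one way to serve this
# executor locally with a Jina Flow and exercise the /upload and /query
# endpoints. Assumes Jina 3.x with docarray<0.30; the port number and the PDF
# path below are placeholders, not values used anywhere else in this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from jina import Flow

    flow = Flow(port=12345).add(uses=MultimodalRAGExecutor)
    with flow:
        # Index a local PDF (placeholder path).
        flow.post(
            on="/upload",
            inputs=DocumentArray([Document(uri="file:///path/to/paper.pdf")]),
        )
        # Ask a question about the indexed content and print the answer.
        replies = flow.post(
            on="/query",
            inputs=DocumentArray([Document(text="What is the main contribution?")]),
        )
        print(replies[0].text)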