| """RAG pipeline for OSINT investigation assistant""" | |
| from typing import Iterator, Optional, List, Tuple | |
| from .vectorstore import OSINTVectorStore, create_vectorstore | |
| from .llm_client import InferenceProviderClient, create_llm_client | |
| from .prompts import ( | |
| SYSTEM_PROMPT, | |
| INVESTIGATION_PROMPT, | |
| get_investigation_prompt | |
| ) | |
class OSINTInvestigationPipeline:
    """RAG pipeline for generating OSINT investigation methodologies"""

    def __init__(
        self,
        vectorstore: Optional[OSINTVectorStore] = None,
        llm_client: Optional[InferenceProviderClient] = None,
        retrieval_k: int = 5
    ):
        """
        Initialize the RAG pipeline

        Args:
            vectorstore: Vector store instance (creates default if None)
            llm_client: LLM client instance (creates default if None)
            retrieval_k: Number of tools to retrieve for context
        """
        self.vectorstore = vectorstore or create_vectorstore()
        self.llm_client = llm_client or create_llm_client()
        self.retrieval_k = retrieval_k

    def retrieve_tools(self, query: str, k: Optional[int] = None) -> List:
        """
        Retrieve relevant OSINT tools for a query

        Args:
            query: User's investigation query
            k: Number of tools to retrieve (uses default if None)

        Returns:
            List of relevant tool documents
        """
        k = k or self.retrieval_k
        return self.vectorstore.similarity_search(query, k=k)

    def generate_methodology(
        self,
        query: str,
        stream: bool = False
    ) -> str | Iterator[str]:
        """
        Generate investigation methodology for a query

        Args:
            query: User's investigation query
            stream: Whether to stream the response

        Returns:
            Generated methodology (string or iterator)
        """
        # Retrieve relevant tools
        relevant_tools = self.retrieve_tools(query)

        # Format tools for context
        context = self.vectorstore.format_tools_for_context(relevant_tools)

        # Generate prompt
        prompt_template = get_investigation_prompt()
        full_prompt = prompt_template.format(query=query, context=context)

        # Generate response
        if stream:
            return self.llm_client.generate_stream(
                prompt=full_prompt,
                system_prompt=SYSTEM_PROMPT
            )
        else:
            return self.llm_client.generate(
                prompt=full_prompt,
                system_prompt=SYSTEM_PROMPT
            )

    def chat(
        self,
        message: str,
        history: Optional[List[Tuple[str, str]]] = None,
        stream: bool = False
    ) -> str | Iterator[str]:
        """
        Handle a chat message with conversation history

        Args:
            message: User's message
            history: Conversation history as list of (user_msg, assistant_msg) tuples
            stream: Whether to stream the response

        Returns:
            Generated response (string or iterator)
        """
        # For now, treat each message as a new investigation query.
        # Follow-up handling could be implemented later (see sketch below).
        return self.generate_methodology(message, stream=stream)
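
    # A minimal sketch of how `history` could be folded into the query so
    # follow-up questions keep their context. This is an assumption about a
    # possible future design, not part of the current pipeline:
    #
    #   def chat(self, message, history=None, stream=False):
    #       if history:
    #           past = "\n".join(
    #               f"User: {u}\nAssistant: {a}" for u, a in history
    #           )
    #           message = f"{past}\nUser: {message}"
    #       return self.generate_methodology(message, stream=stream)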

    def get_tool_recommendations(
        self,
        query: str,
        k: int = 5
    ) -> List[dict]:
        """
        Get tool recommendations with metadata

        Args:
            query: Investigation query
            k: Number of tools to recommend

        Returns:
            List of tool dictionaries with metadata
        """
        docs = self.retrieve_tools(query, k=k)

        tools = []
        for doc in docs:
            tool = {
                "name": doc.metadata.get("name", "Unknown"),
                "category": doc.metadata.get("category", "N/A"),
                "cost": doc.metadata.get("cost", "N/A"),
                "url": doc.metadata.get("url", "N/A"),
                "description": doc.page_content,
                "details": doc.metadata.get("details", "N/A")
            }
            tools.append(tool)

        return tools

    def search_tools_by_category(
        self,
        category: str,
        k: int = 10
    ) -> List[dict]:
        """
        Search tools by category

        Args:
            category: Tool category (e.g., "Archiving", "Social Media")
            k: Number of tools to return

        Returns:
            List of tool dictionaries
        """
        docs = self.vectorstore.similarity_search(
            query=category,
            k=k,
            filter_category=category
        )

        tools = []
        for doc in docs:
            tool = {
                "name": doc.metadata.get("name", "Unknown"),
                "category": doc.metadata.get("category", "N/A"),
                "cost": doc.metadata.get("cost", "N/A"),
                "url": doc.metadata.get("url", "N/A"),
                "description": doc.page_content
            }
            tools.append(tool)

        return tools


def create_pipeline(
    retrieval_k: int = 5,
    model: str = "meta-llama/Llama-3.1-8B-Instruct",
    temperature: float = 0.2
) -> OSINTInvestigationPipeline:
    """
    Factory function to create a configured RAG pipeline

    Args:
        retrieval_k: Number of tools to retrieve
        model: LLM model identifier
        temperature: LLM temperature

    Returns:
        Configured OSINTInvestigationPipeline
    """
    vectorstore = create_vectorstore()
    llm_client = create_llm_client(model=model, temperature=temperature)

    return OSINTInvestigationPipeline(
        vectorstore=vectorstore,
        llm_client=llm_client,
        retrieval_k=retrieval_k
    )
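

# A quick smoke-test sketch showing both API shapes of the pipeline:
# retrieval-only recommendations and streamed generation. It assumes the
# sibling modules (.vectorstore, .llm_client, .prompts) are importable, that
# the inference provider's credentials are configured in the environment,
# and that generate_stream yields string chunks. Run as a module
# (python -m <package>.pipeline) so the relative imports resolve.
if __name__ == "__main__":
    pipeline = create_pipeline(retrieval_k=3)

    # Retrieval only: inspect which tools would back the methodology
    for tool in pipeline.get_tool_recommendations(
        "identify the owner of a domain", k=3
    ):
        print(f"- {tool['name']} ({tool['category']}, {tool['cost']})")

    # Full generation, streamed: consume the iterator chunk by chunk
    for chunk in pipeline.generate_methodology(
        "identify the owner of a domain", stream=True
    ):
        print(chunk, end="", flush=True)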