# osint-llm/src/rag_pipeline.py
"""RAG pipeline for OSINT investigation assistant"""
from typing import Iterator, Optional, List, Tuple
from .vectorstore import OSINTVectorStore, create_vectorstore
from .llm_client import InferenceProviderClient, create_llm_client
from .prompts import SYSTEM_PROMPT, get_investigation_prompt
class OSINTInvestigationPipeline:
"""RAG pipeline for generating OSINT investigation methodologies"""
def __init__(
self,
vectorstore: Optional[OSINTVectorStore] = None,
llm_client: Optional[InferenceProviderClient] = None,
retrieval_k: int = 5
):
"""
Initialize the RAG pipeline
Args:
vectorstore: Vector store instance (creates default if None)
llm_client: LLM client instance (creates default if None)
retrieval_k: Number of tools to retrieve for context
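        Example (illustrative; assumes the package defaults from
        create_vectorstore() and create_llm_client() are usable as-is):
            pipeline = OSINTInvestigationPipeline(retrieval_k=3)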
"""
self.vectorstore = vectorstore or create_vectorstore()
self.llm_client = llm_client or create_llm_client()
self.retrieval_k = retrieval_k
def retrieve_tools(self, query: str, k: Optional[int] = None) -> List:
"""
Retrieve relevant OSINT tools for a query
Args:
query: User's investigation query
k: Number of tools to retrieve (uses default if None)
Returns:
List of relevant tool documents
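        Example (illustrative; documents expose .metadata and .page_content
        as returned by the vectorstore's similarity_search):
            docs = pipeline.retrieve_tools("archived copies of a deleted page")
            names = [doc.metadata.get("name") for doc in docs]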
"""
        # Check for None explicitly so an explicit k=0 is not silently
        # replaced by the default (the docstring promises "if None")
        k = k if k is not None else self.retrieval_k
return self.vectorstore.similarity_search(query, k=k)
def generate_methodology(
self,
query: str,
stream: bool = False
) -> str | Iterator[str]:
"""
Generate investigation methodology for a query
Args:
query: User's investigation query
stream: Whether to stream the response
Returns:
Generated methodology (string or iterator)
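        Example (illustrative; stream=True yields text chunks, stream=False
        returns a single string):
            text = pipeline.generate_methodology("investigate a suspicious domain")
            for chunk in pipeline.generate_methodology(
                "investigate a suspicious domain", stream=True
            ):
                print(chunk, end="", flush=True)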
"""
# Retrieve relevant tools
relevant_tools = self.retrieve_tools(query)
# Format tools for context
context = self.vectorstore.format_tools_for_context(relevant_tools)
# Generate prompt
prompt_template = get_investigation_prompt()
full_prompt = prompt_template.format(query=query, context=context)
# Generate response
if stream:
return self.llm_client.generate_stream(
prompt=full_prompt,
system_prompt=SYSTEM_PROMPT
)
else:
return self.llm_client.generate(
prompt=full_prompt,
system_prompt=SYSTEM_PROMPT
)
def chat(
self,
message: str,
history: Optional[List[Tuple[str, str]]] = None,
stream: bool = False
) -> str | Iterator[str]:
"""
Handle a chat message with conversation history
Args:
message: User's message
            history: Conversation history as (user_msg, assistant_msg) tuples
                (currently unused; reserved for follow-up handling)
stream: Whether to stream the response
Returns:
Generated response (string or iterator)
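        Example (illustrative; history is accepted but not yet used, see
        the note in the body below):
            reply = pipeline.chat("Who registered example.com?", history=[])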
"""
        # History is accepted for interface compatibility but not yet used:
        # each message is treated as a fresh investigation query. Follow-up
        # handling could later thread `history` into the prompt.
return self.generate_methodology(message, stream=stream)
def get_tool_recommendations(
self,
query: str,
k: int = 5
) -> List[dict]:
"""
Get tool recommendations with metadata
Args:
query: Investigation query
k: Number of tools to recommend
Returns:
List of tool dictionaries with metadata
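        Example (illustrative output shape, with the keys built below):
            tools = pipeline.get_tool_recommendations("email address lookup", k=3)
            # [{"name": ..., "category": ..., "cost": ..., "url": ...,
            #   "description": ..., "details": ...}, ...]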
"""
docs = self.retrieve_tools(query, k=k)
tools = []
for doc in docs:
tool = {
"name": doc.metadata.get("name", "Unknown"),
"category": doc.metadata.get("category", "N/A"),
"cost": doc.metadata.get("cost", "N/A"),
"url": doc.metadata.get("url", "N/A"),
"description": doc.page_content,
"details": doc.metadata.get("details", "N/A")
}
tools.append(tool)
return tools
def search_tools_by_category(
self,
category: str,
k: int = 10
) -> List[dict]:
"""
Search tools by category
Args:
category: Tool category (e.g., "Archiving", "Social Media")
k: Number of tools to return
Returns:
List of tool dictionaries
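        Example (illustrative; the category string must match the metadata
        values stored in the vectorstore):
            archivers = pipeline.search_tools_by_category("Archiving", k=5)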
"""
docs = self.vectorstore.similarity_search(
query=category,
k=k,
filter_category=category
)
tools = []
for doc in docs:
tool = {
"name": doc.metadata.get("name", "Unknown"),
"category": doc.metadata.get("category", "N/A"),
"cost": doc.metadata.get("cost", "N/A"),
"url": doc.metadata.get("url", "N/A"),
"description": doc.page_content
}
tools.append(tool)
return tools
def create_pipeline(
retrieval_k: int = 5,
model: str = "meta-llama/Llama-3.1-8B-Instruct",
temperature: float = 0.2
) -> OSINTInvestigationPipeline:
"""
Factory function to create a configured RAG pipeline
Args:
retrieval_k: Number of tools to retrieve
model: LLM model identifier
temperature: LLM temperature
Returns:
Configured OSINTInvestigationPipeline
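    Example (illustrative; the model default is the identifier hard-coded
    in this signature):
        pipeline = create_pipeline(retrieval_k=3, temperature=0.0)
        print(pipeline.generate_methodology("trace a burner phone number"))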
"""
vectorstore = create_vectorstore()
llm_client = create_llm_client(model=model, temperature=temperature)
return OSINTInvestigationPipeline(
vectorstore=vectorstore,
llm_client=llm_client,
retrieval_k=retrieval_k
)
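

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch; assumes the vectorstore is
    # already populated and the inference provider credentials are set in
    # the environment).
    demo = create_pipeline()
    for rec in demo.get_tool_recommendations("find archived versions of a website", k=3):
        print(f"{rec['name']} ({rec['category']}): {rec['url']}")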