# NOTE(review): the three lines that were here ("Spaces:", "Sleeping",
# "Sleeping") were scraped Hugging Face Spaces page chrome, not source code;
# converted to comments so the file parses as Python.
"""Supabase PGVector connection and retrieval functionality"""
import os
from typing import List, Dict, Any, Optional
from supabase import create_client, Client
from huggingface_hub import InferenceClient
class Document:
    """Simple document class to match the LangChain ``Document`` interface.

    Attributes (plain data holders, no validation):
        page_content: The document's main text.
        metadata: Arbitrary key/value metadata attached to the document.
    """

    def __init__(self, page_content: str, metadata: dict):
        self.page_content = page_content
        self.metadata = metadata

    def __repr__(self) -> str:
        # Truncate the content preview so large documents don't flood logs.
        preview = self.page_content[:60]
        return f"{type(self).__name__}(page_content={preview!r}, metadata={self.metadata!r})"
class OSINTVectorStore:
    """Manages the connection to a Supabase PGVector database of OSINT tools.

    Query embeddings are produced via the HuggingFace Inference API and
    matched against the ``bellingcat_tools`` table through the
    ``match_bellingcat_tools`` RPC function.
    """

    def __init__(
        self,
        supabase_url: Optional[str] = None,
        supabase_key: Optional[str] = None,
        hf_token: Optional[str] = None,
        embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
    ):
        """
        Initialize the vector store connection.

        Args:
            supabase_url: Supabase project URL (defaults to SUPABASE_URL env var)
            supabase_key: Supabase anon key (defaults to SUPABASE_KEY env var)
            hf_token: HuggingFace API token (defaults to HF_TOKEN env var)
            embedding_model: HuggingFace model used to embed queries

        Raises:
            ValueError: If any required credential is missing from both the
                arguments and the environment.
        """
        # Explicit arguments take precedence; fall back to the environment.
        self.supabase_url = supabase_url or os.getenv("SUPABASE_URL")
        self.supabase_key = supabase_key or os.getenv("SUPABASE_KEY")
        self.hf_token = hf_token or os.getenv("HF_TOKEN")

        if not self.supabase_url or not self.supabase_key:
            raise ValueError("SUPABASE_URL and SUPABASE_KEY environment variables must be set")
        if not self.hf_token:
            raise ValueError("HF_TOKEN environment variable must be set")

        # Supabase client, used for both RPC calls and table queries.
        self.supabase: Client = create_client(self.supabase_url, self.supabase_key)

        # HuggingFace Inference client used to generate query embeddings.
        self.embedding_model = embedding_model
        self.hf_client = InferenceClient(token=self.hf_token)

    def _generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for ``text`` using the HuggingFace Inference API.

        Args:
            text: Text to embed.

        Returns:
            A flat list of floats (768 dimensions for the default
            all-mpnet-base-v2 model).

        Raises:
            RuntimeError: If the inference call fails.
        """
        import numpy as np  # local import, mirroring the original lazy usage

        try:
            # BUGFIX: ``embedding_model`` was stored but never used — the call
            # previously relied on the API's implicit default model. Pass the
            # configured model explicitly so the constructor argument is honored.
            result = self.hf_client.feature_extraction(
                text=text, model=self.embedding_model
            )
        except Exception as e:
            raise RuntimeError(f"Error generating embedding: {str(e)}") from e

        # Normalize the possible return shapes (ndarray, nested list,
        # list of ndarrays) into a flat Python list of floats.
        if isinstance(result, np.ndarray):
            if result.ndim > 1:
                result = result[0]  # batched response: take the first row
            return result.tolist()
        if isinstance(result, list) and result:
            first = result[0]
            if isinstance(first, list):
                return first  # batched list-of-lists: first embedding
            if isinstance(first, np.ndarray):
                return first.tolist()
        return result

    def _match_tools(
        self,
        query: str,
        k: int,
        match_threshold: float,
        filter_category: Optional[str] = None,
        filter_cost: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Embed ``query`` and run the ``match_bellingcat_tools`` RPC.

        Shared by both public search methods (previously duplicated code).

        Returns:
            The raw result rows from the RPC response.

        Raises:
            RuntimeError: If the RPC call fails.
        """
        query_embedding = self._generate_embedding(query)
        try:
            response = self.supabase.rpc(
                'match_bellingcat_tools',
                {
                    'query_embedding': query_embedding,
                    'match_threshold': match_threshold,
                    'match_count': k,
                    'filter_category': filter_category,
                    'filter_cost': filter_cost
                }
            ).execute()
        except Exception as e:
            raise RuntimeError(f"Error performing similarity search: {str(e)}") from e
        return response.data

    @staticmethod
    def _row_to_document(row: Dict[str, Any], include_similarity: bool) -> "Document":
        """Convert one RPC result row into a Document."""
        metadata = {
            'id': row.get('id'),
            'name': row.get('name'),
            'category': row.get('category'),
            'url': row.get('url'),
            'cost': row.get('cost'),
            'details': row.get('details'),
        }
        if include_similarity:
            metadata['similarity'] = row.get('similarity')
        return Document(page_content=row.get('content', ''), metadata=metadata)

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        filter_category: Optional[str] = None,
        filter_cost: Optional[str] = None,
        match_threshold: float = 0.5
    ) -> List["Document"]:
        """
        Perform similarity search on the OSINT tools database.

        Args:
            query: Search query
            k: Number of results to return
            filter_category: Optional category filter
            filter_cost: Optional cost filter (e.g., 'Free', 'Paid')
            match_threshold: Minimum similarity threshold (0.0 to 1.0)

        Returns:
            List of Document objects with relevant OSINT tools (similarity
            score included in each document's metadata).
        """
        rows = self._match_tools(query, k, match_threshold, filter_category, filter_cost)
        return [self._row_to_document(row, include_similarity=True) for row in rows]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5
    ) -> List[tuple]:
        """
        Perform similarity search and return documents with relevance scores.

        Args:
            query: Search query
            k: Number of results to return

        Returns:
            List of (Document, similarity score) tuples.
        """
        # Threshold 0.0 and no filters so every match is returned with its score.
        rows = self._match_tools(query, k, match_threshold=0.0)
        return [
            (self._row_to_document(row, include_similarity=False),
             row.get('similarity', 0.0))
            for row in rows
        ]

    def get_retriever(self, k: int = 5):
        """
        Get a retriever-like object for LangChain compatibility.

        Args:
            k: Number of results to return

        Returns:
            Object exposing ``get_relevant_documents(query)``.
        """
        class SimpleRetriever:
            """Minimal retriever wrapper delegating to the vector store."""

            def __init__(self, vectorstore, k):
                self.vectorstore = vectorstore
                self.k = k

            def get_relevant_documents(self, query: str) -> List["Document"]:
                return self.vectorstore.similarity_search(query, k=self.k)

        return SimpleRetriever(self, k)

    def format_tools_for_context(self, documents: List["Document"]) -> str:
        """
        Format retrieved tools for inclusion in LLM context.

        Args:
            documents: List of retrieved Document objects

        Returns:
            Formatted string with one section per tool, separated by
            "---" dividers.
        """
        formatted_tools = []
        for i, doc in enumerate(documents, 1):
            metadata = doc.metadata
            tool_info = f"""
Tool {i}: {metadata.get('name', 'Unknown')}
Category: {metadata.get('category', 'N/A')}
Cost: {metadata.get('cost', 'N/A')}
URL: {metadata.get('url', 'N/A')}
Description: {doc.page_content}
Details: {metadata.get('details', 'N/A')}
"""
            formatted_tools.append(tool_info.strip())
        return "\n\n---\n\n".join(formatted_tools)

    def get_tool_categories(self) -> List[str]:
        """Return the sorted list of distinct tool categories.

        Falls back to a hard-coded list of common categories when the
        database is unreachable (deliberate best-effort behavior so the
        caller always has options to show).
        """
        try:
            response = self.supabase.table('bellingcat_tools')\
                .select('category')\
                .execute()
            # De-duplicate while skipping NULL/empty categories.
            categories = {item['category'] for item in response.data if item.get('category')}
            return sorted(categories)
        except Exception:
            # Best-effort fallback: common categories.
            return [
                "Archiving",
                "Social Media",
                "Geolocation",
                "Image Analysis",
                "Domain Investigation",
                "Network Analysis",
                "Data Extraction",
                "Verification"
            ]
def create_vectorstore(
    supabase_url: Optional[str] = None,
    supabase_key: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> OSINTVectorStore:
    """Factory function to create and return a configured vector store.

    Generalized to optionally forward credentials; every argument defaults
    to ``None`` so the zero-argument call reads them from the environment,
    exactly as before.

    Args:
        supabase_url: Supabase project URL (defaults to SUPABASE_URL env var)
        supabase_key: Supabase anon key (defaults to SUPABASE_KEY env var)
        hf_token: HuggingFace API token (defaults to HF_TOKEN env var)

    Returns:
        A configured OSINTVectorStore instance.
    """
    return OSINTVectorStore(
        supabase_url=supabase_url,
        supabase_key=supabase_key,
        hf_token=hf_token,
    )