| """Supabase PGVector connection and retrieval functionality for graphics/design documents""" | |
| import os | |
| from typing import List, Dict, Any, Optional | |
| from supabase import create_client, Client | |
| from huggingface_hub import InferenceClient | |
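
# NOTE: This module assumes external resources that are not defined in this file:
# a Supabase Postgres project with the pgvector extension, a `document_embeddings`
# table whose rows include chunk_text, source_type, source_id, title, content_type,
# chunk_index, page_number, word_count, and metadata columns, and a `match_documents`
# RPC function taking (query_embedding, match_threshold, match_count) and returning
# matching rows with a `similarity` score. Credentials are read from the SUPABASE_URL,
# SUPABASE_KEY, JINA_API_KEY, and HF_TOKEN environment variables.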


class Document:
    """Simple document class to match LangChain interface"""

    def __init__(self, page_content: str, metadata: dict):
        self.page_content = page_content
        self.metadata = metadata


class GraphicsVectorStore:
    """Manages connection to Supabase PGVector database with graphics/design document embeddings"""

    def __init__(
        self,
        supabase_url: Optional[str] = None,
        supabase_key: Optional[str] = None,
        hf_token: Optional[str] = None,
        jina_api_key: Optional[str] = None,
        embedding_model: str = "jina-clip-v2"
    ):
| """ | |
| Initialize the vector store connection | |
| Args: | |
| supabase_url: Supabase project URL (defaults to SUPABASE_URL env var) | |
| supabase_key: Supabase anon key (defaults to SUPABASE_KEY env var) | |
| hf_token: HuggingFace API token (defaults to HF_TOKEN env var) | |
| jina_api_key: Jina AI API key (defaults to JINA_API_KEY env var, required for Jina models) | |
| embedding_model: Embedding model to use (default: jinaai/jina-clip-v2) | |
| """ | |
        # Get credentials from parameters or environment
        self.supabase_url = supabase_url or os.getenv("SUPABASE_URL")
        self.supabase_key = supabase_key or os.getenv("SUPABASE_KEY")
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.jina_api_key = jina_api_key or os.getenv("JINA_API_KEY")

        if not self.supabase_url or not self.supabase_key:
            raise ValueError("SUPABASE_URL and SUPABASE_KEY environment variables must be set")

        # Check for the appropriate API key based on the selected model
        self.embedding_model = embedding_model
        if "jina" in self.embedding_model.lower():
            if not self.jina_api_key:
                raise ValueError("JINA_API_KEY environment variable must be set for Jina models")
        else:
            if not self.hf_token:
                raise ValueError("HF_TOKEN environment variable must be set for HuggingFace models")

        # Initialize Supabase client
        self.supabase: Client = create_client(self.supabase_url, self.supabase_key)

        # Initialize HuggingFace Inference client when an HF token is available
        # (note: _generate_embedding below calls the Jina API, not this client)
        if self.hf_token:
            self.hf_client = InferenceClient(token=self.hf_token)

    def _generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for text via the Jina AI Embeddings API

        Args:
            text: Text to embed

        Returns:
            List of floats representing the embedding vector (1024 dimensions for jina-clip-v2)
        """
        try:
            import requests
            import numpy as np

            # Jina AI exposes its own embeddings endpoint
            api_url = "https://api.jina.ai/v1/embeddings"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.jina_api_key}"
            }
            payload = {
                "model": self.embedding_model,
                "input": [text]
            }
            response = requests.post(api_url, headers=headers, json=payload, timeout=30)
            if response.status_code != 200:
                raise Exception(f"API returned status {response.status_code}: {response.text}")

            result = response.json()

            # Jina API returns embeddings in a 'data' array
            if isinstance(result, dict) and 'data' in result:
                return result['data'][0]['embedding']

            # Fallback to standard response parsing for other shapes
            result = result if not isinstance(result, dict) else result.get('embeddings', result)

            # If it's a numpy array, convert to a list
            if isinstance(result, np.ndarray):
                if result.ndim > 1:
                    result = result[0]  # Take first row if 2D
                return result.tolist()

            # If it's a (possibly nested) list, return a single flat embedding
            if isinstance(result, list) and len(result) > 0:
                if isinstance(result[0], list):
                    return result[0]  # Take first embedding if batched
                if isinstance(result[0], np.ndarray):
                    return result[0].tolist()
                return result

            return result
        except Exception as e:
            raise Exception(f"Error generating embedding with {self.embedding_model}: {str(e)}")
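
    # Illustrative request/response shape for the Jina embeddings call above, inferred
    # from the parsing logic in _generate_embedding (not verified here):
    #   request body:  {"model": "jina-clip-v2", "input": ["query text"]}
    #   response body: {"data": [{"embedding": [0.012, -0.034, ...]}], ...}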

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        match_threshold: float = 0.3
    ) -> List[Document]:
        """
        Perform similarity search on the graphics/design document database

        Args:
            query: Search query
            k: Number of results to return
            match_threshold: Minimum similarity threshold (0.0 to 1.0)

        Returns:
            List of Document objects with relevant document chunks
        """
        # Generate embedding for query
        query_embedding = self._generate_embedding(query)

        # Call RPC function
        try:
            response = self.supabase.rpc(
                'match_documents',
                {
                    'query_embedding': query_embedding,
                    'match_threshold': match_threshold,
                    'match_count': k
                }
            ).execute()

            # Convert results to Document objects
            documents = []
            for item in response.data:
                # Handle None chunk_text
                chunk_text = item.get('chunk_text') or ''
                doc = Document(
                    page_content=chunk_text,
                    metadata={
                        'id': item.get('id'),
                        'source_type': item.get('source_type'),
                        'source_id': item.get('source_id'),
                        'title': item.get('title', ''),
                        'content_type': item.get('content_type'),
                        'chunk_index': item.get('chunk_index'),
                        'page_number': item.get('page_number'),
                        'word_count': item.get('word_count'),
                        'metadata': item.get('metadata', {}),
                        'similarity': item.get('similarity')
                    }
                )
                documents.append(doc)

            return documents
        except Exception as e:
            raise Exception(f"Error performing similarity search: {str(e)}")

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5
    ) -> List[tuple]:
        """
        Perform similarity search and return documents with relevance scores

        Args:
            query: Search query
            k: Number of results to return

        Returns:
            List of tuples (Document, score)
        """
        # Generate embedding for query
        query_embedding = self._generate_embedding(query)

        # Call RPC function
        try:
            response = self.supabase.rpc(
                'match_documents',
                {
                    'query_embedding': query_embedding,
                    'match_threshold': 0.0,  # Get all matches
                    'match_count': k
                }
            ).execute()

            # Convert results to Document objects with scores
            results = []
            for item in response.data:
                # Handle None chunk_text
                chunk_text = item.get('chunk_text') or ''
                doc = Document(
                    page_content=chunk_text,
                    metadata={
                        'id': item.get('id'),
                        'source_type': item.get('source_type'),
                        'source_id': item.get('source_id'),
                        'title': item.get('title', ''),
                        'content_type': item.get('content_type'),
                        'chunk_index': item.get('chunk_index'),
                        'page_number': item.get('page_number'),
                        'word_count': item.get('word_count'),
                        'metadata': item.get('metadata', {})
                    }
                )
                score = item.get('similarity', 0.0)
                results.append((doc, score))

            return results
        except Exception as e:
            raise Exception(f"Error performing similarity search: {str(e)}")

    def get_retriever(self, k: int = 5):
        """
        Get a retriever-like object for LangChain compatibility

        Args:
            k: Number of results to return

        Returns:
            Simple retriever object with a get_relevant_documents method
        """
        class SimpleRetriever:
            def __init__(self, vectorstore, k):
                self.vectorstore = vectorstore
                self.k = k

            def get_relevant_documents(self, query: str) -> List[Document]:
                return self.vectorstore.similarity_search(query, k=self.k)

        return SimpleRetriever(self, k)

    def format_documents_for_context(self, documents: List[Document]) -> str:
        """
        Format retrieved documents for inclusion in LLM context

        Args:
            documents: List of retrieved Document objects

        Returns:
            Formatted string with document information
        """
        formatted_docs = []
        for i, doc in enumerate(documents, 1):
            metadata = doc.metadata
            source_info = f"Source: {metadata.get('source_id', 'Unknown')}"
            if metadata.get('page_number'):
                source_info += f" (Page {metadata.get('page_number')})"

            doc_info = f"""
Document {i}: {source_info}
Type: {metadata.get('source_type', 'N/A')} | Content: {metadata.get('content_type', 'text')}
{doc.page_content}
"""
            formatted_docs.append(doc_info.strip())

        return "\n\n---\n\n".join(formatted_docs)

    def get_source_types(self) -> List[str]:
        """Get list of available source types from database"""
        try:
            response = self.supabase.table('document_embeddings')\
                .select('source_type')\
                .execute()

            # Extract unique source types
            source_types = set()
            for item in response.data:
                if item.get('source_type'):
                    source_types.add(item['source_type'])

            return sorted(source_types)
        except Exception:
            # Return common source types as fallback
            return [
                "pdf",
                "url",
                "image"
            ]


def create_vectorstore() -> GraphicsVectorStore:
    """Factory function to create and return a configured vector store"""
    return GraphicsVectorStore()
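

# Minimal usage sketch (not part of the library surface). It assumes the environment
# variables checked in __init__ (SUPABASE_URL, SUPABASE_KEY, JINA_API_KEY) are set and
# that the `match_documents` RPC exists in the target Supabase project; the query text
# and k values below are arbitrary examples.
if __name__ == "__main__":
    store = create_vectorstore()

    # Plain similarity search returning Document objects
    docs = store.similarity_search("color theory for poster design", k=3)
    print(store.format_documents_for_context(docs))

    # Same search, but with similarity scores attached
    for doc, score in store.similarity_search_with_score("grid systems in layout", k=3):
        print(f"{score:.3f}  {doc.metadata.get('title', '')}")

    # LangChain-style retriever wrapper
    retriever = store.get_retriever(k=3)
    relevant = retriever.get_relevant_documents("typography pairing guidelines")
    print(f"Retrieved {len(relevant)} documents")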