| """RAG pipeline for Graphics Guide / Design Assistant""" | |
| from typing import Iterator, Optional, List, Tuple | |
| from .vectorstore import GraphicsVectorStore, create_vectorstore | |
| from .llm_client import InferenceProviderClient, create_llm_client | |
| from .prompts import ( | |
| SYSTEM_PROMPT, | |
| DESIGN_PROMPT, | |
| get_design_prompt | |
| ) | |

class GraphicsDesignPipeline:
    """RAG pipeline for generating graphics and design recommendations."""

    def __init__(
        self,
        vectorstore: Optional[GraphicsVectorStore] = None,
        llm_client: Optional[InferenceProviderClient] = None,
        retrieval_k: int = 5,
    ):
        """
        Initialize the RAG pipeline.

        Args:
            vectorstore: Vector store instance (creates default if None)
            llm_client: LLM client instance (creates default if None)
            retrieval_k: Number of document chunks to retrieve for context
        """
        self.vectorstore = vectorstore or create_vectorstore()
        self.llm_client = llm_client or create_llm_client()
        self.retrieval_k = retrieval_k
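
    # Example (sketch): injecting preconfigured components instead of relying
    # on the defaults. `create_llm_client(temperature=...)` matches the factory
    # call used in `create_pipeline` below.
    #
    #     pipeline = GraphicsDesignPipeline(
    #         llm_client=create_llm_client(temperature=0.7),
    #         retrieval_k=8,
    #     )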

    def retrieve_documents(self, query: str, k: Optional[int] = None) -> List:
        """
        Retrieve relevant document chunks for a query.

        Args:
            query: User's design query
            k: Number of documents to retrieve (uses default if None)

        Returns:
            List of relevant document chunks
        """
        # Explicit None check so a caller-supplied k=0 isn't silently replaced
        k = self.retrieval_k if k is None else k
        return self.vectorstore.similarity_search(query, k=k)

    def generate_recommendations(
        self,
        query: str,
        stream: bool = False,
    ) -> str | Iterator[str]:
        """
        Generate design recommendations for a query.

        Args:
            query: User's design query
            stream: Whether to stream the response

        Returns:
            Generated recommendations (string or iterator)
        """
        # Retrieve relevant documents
        relevant_docs = self.retrieve_documents(query)

        # Format documents for context
        context = self.vectorstore.format_documents_for_context(relevant_docs)

        # Build the full prompt from the design template
        prompt_template = get_design_prompt()
        full_prompt = prompt_template.format(query=query, context=context)

        # Generate response
        if stream:
            return self.llm_client.generate_stream(
                prompt=full_prompt,
                system_prompt=SYSTEM_PROMPT,
            )
        else:
            return self.llm_client.generate(
                prompt=full_prompt,
                system_prompt=SYSTEM_PROMPT,
            )
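
    # Example (sketch; assumes a populated vector store and a configured LLM
    # client — the queries are illustrative):
    #
    #     pipeline = GraphicsDesignPipeline()
    #     text = pipeline.generate_recommendations("logo palette for a coffee brand")
    #     for chunk in pipeline.generate_recommendations("poster layout tips", stream=True):
    #         print(chunk, end="", flush=True)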

    def chat(
        self,
        message: str,
        history: Optional[List[Tuple[str, str]]] = None,
        stream: bool = False,
    ) -> str | Iterator[str]:
        """
        Handle a chat message with conversation history.

        Args:
            message: User's message
            history: Conversation history as a list of (user_msg, assistant_msg) tuples
            stream: Whether to stream the response

        Returns:
            Generated response (string or iterator)
        """
        # For now, treat each message as a new design query.
        # In the future, follow-up handling could make use of `history`.
        return self.generate_recommendations(message, stream=stream)
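
    # Example (sketch): a Gradio-style chat loop; the (user, assistant) tuple
    # format mirrors the `history` annotation above.
    #
    #     history: List[Tuple[str, str]] = []
    #     reply = pipeline.chat("Suggest fonts for a tech blog", history=history)
    #     history.append(("Suggest fonts for a tech blog", reply))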

    def get_relevant_examples(
        self,
        query: str,
        k: int = 5,
    ) -> List[dict]:
        """
        Get relevant examples and knowledge chunks with their metadata.

        Args:
            query: Design query
            k: Number of examples to retrieve

        Returns:
            List of document dictionaries with metadata
        """
        docs = self.retrieve_documents(query, k=k)
        examples = []
        for doc in docs:
            example = {
                "source": doc.metadata.get("source_id", "Unknown"),
                "source_type": doc.metadata.get("source_type", "N/A"),
                "page": doc.metadata.get("page_number"),
                "content": doc.page_content,
                "similarity": doc.metadata.get("similarity"),
            }
            examples.append(example)
        return examples
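
    # Each returned entry is shaped roughly like this (values illustrative;
    # actual keys depend on the metadata attached at ingestion time):
    #
    #     {"source": "design_handbook.pdf", "source_type": "pdf",
    #      "page": 12, "content": "...", "similarity": 0.83}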


def create_pipeline(
    retrieval_k: int = 5,
    model: str = "meta-llama/Llama-3.1-8B-Instruct",
    temperature: float = 0.2,
) -> GraphicsDesignPipeline:
    """
    Factory function to create a configured RAG pipeline.

    Args:
        retrieval_k: Number of documents to retrieve
        model: LLM model identifier
        temperature: LLM temperature

    Returns:
        Configured GraphicsDesignPipeline
    """
    vectorstore = create_vectorstore()
    llm_client = create_llm_client(model=model, temperature=temperature)
    return GraphicsDesignPipeline(
        vectorstore=vectorstore,
        llm_client=llm_client,
        retrieval_k=retrieval_k,
    )
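

if __name__ == "__main__":
    # Minimal smoke test (sketch): assumes the vector store has already been
    # built/populated and that credentials for the inference provider are
    # configured in the environment.
    pipeline = create_pipeline(retrieval_k=3)

    print("Retrieved context:")
    for example in pipeline.get_relevant_examples("minimalist poster design", k=3):
        print(f"- {example['source']} (page {example['page']})")

    print("\nRecommendations:")
    print(pipeline.generate_recommendations("minimalist poster design"))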