# graphics-llm/src/rag_pipeline.py
"""RAG pipeline for Graphics Guide / Design Assistant"""
from typing import Iterator, Optional, List, Tuple
from .vectorstore import GraphicsVectorStore, create_vectorstore
from .llm_client import InferenceProviderClient, create_llm_client
from .prompts import (
    SYSTEM_PROMPT,
    DESIGN_PROMPT,
    get_design_prompt
)
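
# Pipeline overview: retrieve the top-k relevant chunks from the vector store,
# format them into the design prompt as context, then ask the LLM client for
# recommendations (optionally streamed).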


class GraphicsDesignPipeline:
    """RAG pipeline for generating graphics and design recommendations"""

    def __init__(
        self,
        vectorstore: Optional[GraphicsVectorStore] = None,
        llm_client: Optional[InferenceProviderClient] = None,
        retrieval_k: int = 5
    ):
        """
        Initialize the RAG pipeline

        Args:
            vectorstore: Vector store instance (creates default if None)
            llm_client: LLM client instance (creates default if None)
            retrieval_k: Number of document chunks to retrieve for context
        """
        self.vectorstore = vectorstore or create_vectorstore()
        self.llm_client = llm_client or create_llm_client()
        self.retrieval_k = retrieval_k

    def retrieve_documents(self, query: str, k: Optional[int] = None) -> List:
        """
        Retrieve relevant document chunks for a query

        Args:
            query: User's design query
            k: Number of documents to retrieve (uses default if None)

        Returns:
            List of relevant document chunks
        """
        k = k or self.retrieval_k
        return self.vectorstore.similarity_search(query, k=k)

    def generate_recommendations(
        self,
        query: str,
        stream: bool = False
    ) -> str | Iterator[str]:
        """
        Generate design recommendations for a query

        Args:
            query: User's design query
            stream: Whether to stream the response

        Returns:
            Generated recommendations (string or iterator)
        """
        # Retrieve relevant documents
        relevant_docs = self.retrieve_documents(query)

        # Format documents for context
        context = self.vectorstore.format_documents_for_context(relevant_docs)

        # Generate prompt
        prompt_template = get_design_prompt()
        full_prompt = prompt_template.format(query=query, context=context)

        # Generate response
        if stream:
            return self.llm_client.generate_stream(
                prompt=full_prompt,
                system_prompt=SYSTEM_PROMPT
            )
        else:
            return self.llm_client.generate(
                prompt=full_prompt,
                system_prompt=SYSTEM_PROMPT
            )

    def chat(
        self,
        message: str,
        history: Optional[List[Tuple[str, str]]] = None,
        stream: bool = False
    ) -> str | Iterator[str]:
        """
        Handle a chat message with conversation history

        Args:
            message: User's message
            history: Conversation history as list of (user_msg, assistant_msg) tuples
            stream: Whether to stream the response

        Returns:
            Generated response (string or iterator)
        """
        # For now, treat each message as a new design query
        # In the future, could implement follow-up handling
        return self.generate_recommendations(message, stream=stream)

    def get_relevant_examples(
        self,
        query: str,
        k: int = 5
    ) -> List[dict]:
        """
        Get relevant examples and knowledge with metadata

        Args:
            query: Design query
            k: Number of examples to recommend

        Returns:
            List of document dictionaries with metadata
        """
        docs = self.retrieve_documents(query, k=k)
        examples = []
        for doc in docs:
            example = {
                "source": doc.metadata.get("source_id", "Unknown"),
                "source_type": doc.metadata.get("source_type", "N/A"),
                "page": doc.metadata.get("page_number"),
                "content": doc.page_content,
                "similarity": doc.metadata.get("similarity")
            }
            examples.append(example)
        return examples


def create_pipeline(
    retrieval_k: int = 5,
    model: str = "meta-llama/Llama-3.1-8B-Instruct",
    temperature: float = 0.2
) -> GraphicsDesignPipeline:
    """
    Factory function to create a configured RAG pipeline

    Args:
        retrieval_k: Number of documents to retrieve
        model: LLM model identifier
        temperature: LLM temperature

    Returns:
        Configured GraphicsDesignPipeline
    """
    vectorstore = create_vectorstore()
    llm_client = create_llm_client(model=model, temperature=temperature)
    return GraphicsDesignPipeline(
        vectorstore=vectorstore,
        llm_client=llm_client,
        retrieval_k=retrieval_k
    )
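

# Illustrative usage sketch (not part of the original module): assumes the
# file is run as a package module (e.g. `python -m src.rag_pipeline` from the
# repo root) so the relative imports resolve, and that the LLM client has
# valid inference-provider credentials configured. The example query is made up.
if __name__ == "__main__":
    pipeline = create_pipeline(retrieval_k=3)

    query = "Suggest a color palette and layout for a data-heavy dashboard"

    # Non-streaming generation returns a single string.
    recommendations = pipeline.generate_recommendations(query)
    print(recommendations)

    # Streaming variant yields chunks instead:
    # for chunk in pipeline.generate_recommendations(query, stream=True):
    #     print(chunk, end="", flush=True)

    # Inspect which knowledge-base chunks grounded the answer.
    for example in pipeline.get_relevant_examples(query, k=3):
        print(example["source"], example["page"], example["similarity"])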