Update pipeline.py
pipeline.py  (CHANGED: +21 -20)
@@ -13,7 +13,7 @@ from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMMod
 from pydantic import BaseModel, ValidationError, validator
 from mistralai import Mistral
 from langchain.prompts import PromptTemplate
-
+from langchain_google_genai import ChatGoogleGenerativeAI
 # Import chains and tools
 from classification_chain import get_classification_chain
 from cleaner_chain import get_cleaner_chain
@@ -25,6 +25,13 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
+gemini_llm = ChatGoogleGenerativeAI(
+    model="gemini-1.5-pro",
+    temperature=0.5,
+    max_retries=2,
+    google_api_key=os.environ.get("GEMINI_API_KEY"),
+    # Additional parameters or safety_settings can be added here if needed
+)
 # Load spaCy model for NER and download it if not already installed
 def install_spacy_model():
     try:
@@ -131,25 +138,19 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     return vectorstore
 
 # Function to build RAG chain
-def build_rag_chain(
-    ...
-        llm=gemini_as_llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True
-    )
-    return rag_chain
+def build_rag_chain(vectorstore: FAISS) -> RetrievalQA:
+    """Build RAG chain using the Gemini LLM directly without a custom class."""
+    try:
+        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+        chain = RetrievalQA.from_chain_type(
+            llm=gemini_llm,  # Directly use the ChatGoogleGenerativeAI instance
+            chain_type="stuff",
+            retriever=retriever,
+            return_source_documents=True
+        )
+        return chain
+    except Exception as e:
+        raise RuntimeError(f"Error building RAG chain: {str(e)}")
 
 # Function to perform web search using DuckDuckGo
 def do_web_search(query: str) -> str:
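
For reference, a minimal usage sketch of the updated build_rag_chain as it might be wired together elsewhere in pipeline.py. It assumes RetrievalQA is imported from langchain.chains (the import is not visible in this diff), that build_or_load_vectorstore from the surrounding file produces the FAISS store, and that the CSV path, store directory, and query string are illustrative placeholders only.

# Usage sketch (assumptions noted above; not part of the committed diff)
from langchain.chains import RetrievalQA  # assumed to already be imported in pipeline.py

vectorstore = build_or_load_vectorstore("data/knowledge.csv", "faiss_store")  # illustrative paths
rag_chain = build_rag_chain(vectorstore)

# With return_source_documents=True, RetrievalQA returns the answer under "result"
# and the retrieved documents under "source_documents"
response = rag_chain({"query": "Example question about the indexed CSV data"})
print(response["result"])
for doc in response["source_documents"]:
    print(doc.metadata)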