Update pipeline.py
pipeline.py  +21 -68
CHANGED
@@ -1,5 +1,7 @@
 import os
 import getpass
+from pydantic_ai import Agent  # Import the Agent from pydantic_ai
+from pydantic_ai.models.mistral import MistralModel  # Import the Mistral model
 import spacy  # Import spaCy for NER functionality
 import pandas as pd
 from typing import Optional
@@ -10,33 +12,18 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import subprocess  # Import subprocess to run shell commands
 from langchain.llms.base import LLM  # Import LLM
-
-
-#
-
-
-
-
-
-# 1) Environment: set up keys if missing
-if not os.environ.get("GEMINI_API_KEY"):
-    os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
-if not os.environ.get("GROQ_API_KEY"):
-    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
-if not os.environ.get("MISTRAL_API_KEY"):
-    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
-
-# Initialize Mistral client
-mistral_client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
-
-# 2) Load spaCy model for NER and download the spaCy model if not already installed
+
+# Initialize Mistral agent using Pydantic AI
+mistral_api_key = os.environ.get("MISTRAL_API_KEY")  # Ensure your Mistral API key is set
+mistral_model = MistralModel("mistral-large-latest", api_key=mistral_api_key)  # Use a Mistral model
+mistral_agent = Agent(mistral_model)
+
+# Load spaCy model for NER and download the spaCy model if not already installed
 def install_spacy_model():
     try:
-        # Check if the model is already installed
         spacy.load("en_core_web_sm")
         print("spaCy model 'en_core_web_sm' is already installed.")
     except OSError:
-        # If model is not installed, download it using subprocess
         print("Downloading spaCy model 'en_core_web_sm'...")
         subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"], check=True)
         print("spaCy model 'en_core_web_sm' downloaded successfully.")
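This hunk replaces the interactive getpass prompts and the direct Mistral client with a Pydantic AI Agent. A minimal sketch of how that agent would typically be driven, assuming an early pydantic_ai release where MistralModel accepts an api_key kwarg (as the commit itself does) and run results expose .data:

import os
from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

# Mirror the commit's setup; the api_key kwarg matches early pydantic_ai releases
mistral_model = MistralModel("mistral-large-latest",
                             api_key=os.environ["MISTRAL_API_KEY"])
mistral_agent = Agent(mistral_model)

# run_sync performs one synchronous exchange and returns a result object
result = mistral_agent.run_sync("Reply with OK if you can read this.")
print(result.data)  # .data carries the text output in early pydantic_ai versions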
@@ -47,46 +34,16 @@ install_spacy_model()
 # Load the spaCy model globally
 nlp = spacy.load("en_core_web_sm")
 
-# Function to
-def extract_main_topic(query: str) -> str:
-    """
-    Extracts the main topic from the user's query using spaCy's NER.
-    Returns the first named entity or noun found in the query.
-    """
-    doc = nlp(query)  # Use the globally loaded spaCy model
-
-    # Try to extract the main topic as a named entity (person, product, etc.)
-    main_topic = None
-    for ent in doc.ents:
-        # Filter for specific entity types (you can adjust this based on your needs)
-        if ent.label_ in ["ORG", "PRODUCT", "PERSON", "GPE", "TIME"]:  # Add more entity labels as needed
-            main_topic = ent.text
-            break
-
-    # If no named entity found, fallback to extracting the first noun or proper noun
-    if not main_topic:
-        for token in doc:
-            if token.pos_ in ["NOUN", "PROPN"]:  # Extract first noun or proper noun
-                main_topic = token.text
-                break
-
-    # Return the extracted topic or a fallback value if no topic is found
-    return main_topic if main_topic else "this topic"
-
-# 3) Function to moderate text using Mistral moderation API
+# Function to moderate text using Pydantic AI's Mistral moderation model
 def moderate_text(query: str) -> str:
     """
-    Classifies the query as harmful or not using Mistral Moderation
+    Classifies the query as harmful or not using Mistral Moderation via Pydantic AI.
     Returns "OutOfScope" if harmful, otherwise returns the original query.
     """
-    response =
-
-
-
-
-    categories = response.results[0].categories
-
-    # Check if any harmful category is flagged
+    response = mistral_agent.call("classify", {"inputs": [query]})
+    categories = response['results'][0]['categories']
+
+    # Check if harmful content is flagged in moderation categories
     if categories.get("violence_and_threats", False) or \
        categories.get("hate_and_discrimination", False) or \
       categories.get("dangerous_and_criminal_content", False) or \
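Note that Agent.call is not a documented pydantic_ai method, so the new moderation path may need follow-up. For comparison, a hedged sketch of the same check against Mistral's own SDK, matching the removed code's response.results[0].categories access pattern; the model id is an assumption:

import os
from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

def moderate_text_direct(query: str) -> str:
    # "mistral-moderation-latest" is an assumed model id
    response = client.classifiers.moderate(
        model="mistral-moderation-latest",
        inputs=[query],
    )
    categories = response.results[0].categories  # maps category name -> bool
    flagged = ("violence_and_threats", "hate_and_discrimination",
               "dangerous_and_criminal_content")
    if any(categories.get(name, False) for name in flagged):
        return "OutOfScope"
    return query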
@@ -94,7 +51,7 @@ def moderate_text(query: str) -> str:
         return "OutOfScope"
     return query
 
-#
+# 3) build_or_load_vectorstore (no changes)
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
         print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading...")
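build_or_load_vectorstore itself is untouched by this commit. For context, a minimal sketch of the load-or-build pattern it implements; the embedding class and the CSV column name are placeholder assumptions:

import os
import pandas as pd
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

def load_or_build(csv_path: str, store_dir: str) -> FAISS:
    embeddings = HuggingFaceEmbeddings()  # placeholder embedding model
    if os.path.exists(store_dir):
        # Reuse the persisted index instead of re-embedding on every run
        # (newer langchain versions also require allow_dangerous_deserialization=True)
        return FAISS.load_local(store_dir, embeddings)
    df = pd.read_csv(csv_path)
    texts = df["answer"].astype(str).tolist()  # column name is an assumption
    vectorstore = FAISS.from_texts(texts, embeddings)
    vectorstore.save_local(store_dir)
    return vectorstore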
@@ -123,7 +80,7 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     vectorstore.save_local(store_dir)
     return vectorstore
 
-#
+# 4) Build RAG chain for Gemini (no changes)
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
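build_rag_chain is also unchanged. The pattern it relies on, wrapping a model behind LangChain's LLM interface so RetrievalQA can drive it, looks roughly like this sketch; the stub model is a stand-in for gemini_llm, and vectorstore is assumed to come from build_or_load_vectorstore:

from typing import Optional
from langchain.llms.base import LLM
from langchain.chains import RetrievalQA

class WrappedLLM(LLM):  # hypothetical stand-in for GeminiLangChainLLM
    @property
    def _llm_type(self) -> str:
        return "wrapped-llm"

    def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
        # Delegate to the underlying model; here a stub instead of gemini_llm
        return f"stub answer for: {prompt[:40]}"

rag_chain = RetrievalQA.from_chain_type(
    llm=WrappedLLM(),
    chain_type="stuff",  # stuff all retrieved chunks into a single prompt
    retriever=vectorstore.as_retriever(),  # vectorstore built in the step above
)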
@@ -144,13 +101,13 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     )
     return rag_chain
 
-#
+# 5) Initialize all the separate chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()  # Refusal chain will now use dynamic topic
 tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
-#
+# 6) Build our vectorstores + RAG chains
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
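These constants feed the vectorstore construction that sits between this hunk and the next (outside the diff's context window). A hedged sketch of that wiring plus a query via the legacy .run API; brand_store_dir is an assumed name:

# Hypothetical wiring inferred from the surrounding context lines
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
answer = wellness_rag_chain.run("What are some tips for better sleep?")
print(answer)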
@@ -163,7 +120,7 @@ gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
-#
+# 7) Tools / Agents for web search (no changes)
 search_tool = DuckDuckGoSearchTool()
 web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
 managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
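The web-search wiring is likewise untouched. Sketched below is the manager/managed-agent pattern from the older smolagents API that do_web_search builds on; the manager_agent construction is an assumption based on the manager_agent.run call later in the file:

from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent

search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search",
                                 description="Runs web search for you.")
# A manager agent delegates to its managed agents by name
manager_agent = CodeAgent(tools=[], model=gemini_llm,
                          managed_agents=[managed_web_agent])
print(manager_agent.run("Find recent FAISS indexing tutorials"))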
@@ -175,14 +132,10 @@ def do_web_search(query: str) -> str:
     response = manager_agent.run(search_query)
     return response
 
-#
+# 8) Orchestrator: run_with_chain
 def run_with_chain(query: str) -> str:
     print("DEBUG: Starting run_with_chain...")
 
-
-    # Ensure the query is a string
-    query = str(query).strip()
-
     # 1) Moderate the query for harmful content
     moderated_query = moderate_text(query)
     if moderated_query == "OutOfScope":