Update pipeline.py

pipeline.py CHANGED (+28 -24)
```diff
@@ -10,7 +10,7 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.vectorstores import FAISS
 from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, ValidationError, validator
 from mistralai import Mistral
 from langchain.prompts import PromptTemplate
 
```
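A note on the unchanged imports: `langchain.embeddings`, `langchain.vectorstores`, and `langchain.chains` are the pre-0.1 LangChain paths, so this Space presumably pins an older release. On current LangChain the first two usually come from `langchain-community`; a hedged sketch:

```python
# Equivalent imports on newer LangChain releases (assumes langchain-community is installed).
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA  # still exported from langchain.chains
```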
```diff
@@ -25,6 +25,9 @@ from prompts import classification_prompt, refusal_prompt, tailor_prompt
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")
 client = Mistral(api_key=mistral_api_key)
 
+# Initialize Pydantic AI Agent (for text validation)
+pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
+
 # Load spaCy model for NER and download it if not already installed
 def install_spacy_model():
     try:
```
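The `Agent` used above has no import in the visible hunks; assuming it is pydantic-ai's `Agent` (the model string and the `result_type` argument match that API), the missing import would look like the sketch below. Note that `pydantic_agent` is not referenced anywhere else in the diff shown here.

```python
# Assumed import for the Agent initialized above; not shown in this diff.
from pydantic_ai import Agent

# Mirrors the line added in the hunk above.
pydantic_agent = Agent('mistral:mistral-large-latest', result_type=str)
```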
```diff
@@ -53,6 +56,17 @@ def extract_main_topic(query: str) -> str:
             break
     return main_topic if main_topic else "this topic"
 
+# Pydantic model to handle string input validation
+class QueryInput(BaseModel):
+    query: str
+
+    # Validator to ensure the query is always a string
+    @validator('query')
+    def check_query_is_string(cls, v):
+        if not isinstance(v, str):
+            raise ValueError("Query must be a valid string.")
+        return v
+
 # Function to classify query based on wellness topics
 def classify_query(query: str) -> str:
     wellness_keywords = ["box breathing", "meditation", "yoga", "mindfulness", "breathing exercises"]
```
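The new `QueryInput` model uses the Pydantic v1-style `@validator` (on Pydantic v2 it is deprecated in favor of `@field_validator`). A minimal sketch of how it behaves, assuming v1 semantics:

```python
from pydantic import BaseModel, ValidationError, validator

class QueryInput(BaseModel):
    query: str

    @validator('query')
    def check_query_is_string(cls, v):
        if not isinstance(v, str):
            raise ValueError("Query must be a valid string.")
        return v

QueryInput(query="What is box breathing?")  # passes validation

try:
    QueryInput(query=None)  # fails: None is not a valid string
except ValidationError as e:
    print(e)
```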
```diff
@@ -63,45 +77,31 @@ def classify_query(query: str) -> str:
     classification = class_result.get("text", "").strip()
     return classification if classification != "OutOfScope" else "OutOfScope"
 
-#
-
-    text: str
-
-# Function to validate the text input using Pydantic
-def validate_text(query: str) -> str:
+# Function to moderate text using Mistral moderation API (sync version)
+def moderate_text(query: str) -> str:
     try:
-        #
-
-        return query
+        # Use Pydantic to validate text input
+        query_input = QueryInput(query=query)  # This will validate that the query is a string
     except ValidationError as e:
         print(f"Error validating text: {e}")
         return "Invalid text format."
-
-# Function to moderate text using Mistral moderation API (synchronous version)
-def moderate_text(query: str) -> str:
-    # Validate the text using Pydantic
-    validated_text = validate_text(query)
-    if validated_text == "Invalid text format.":
-        return validated_text
 
     # Call the Mistral moderation API
     response = client.classifiers.moderate_chat(
         model="mistral-moderation-latest",
-        inputs=[{"role": "user", "content":
+        inputs=[{"role": "user", "content": query}]
     )
 
-    #
-    # check if it has a 'results' attribute, and then access its categories
+    # Check if harmful categories are present in the response
     if hasattr(response, 'results') and response.results:
         categories = response.results[0].categories
-        # Check if harmful categories are present
         if categories.get("violence_and_threats", False) or \
            categories.get("hate_and_discrimination", False) or \
            categories.get("dangerous_and_criminal_content", False) or \
           categories.get("selfharm", False):
             return "OutOfScope"
 
-    return
+    return query
 
 
 # Function to build or load the vector store from CSV data
```

(Several removed lines in this hunk, including the bare `#` fragments and the truncated `inputs=[...]` and `return` lines, were cut off in this view; the fragments are kept as rendered.)
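The moderation call introduced above can be exercised on its own. A minimal sketch, assuming the mistralai v1 SDK and a `MISTRAL_API_KEY` in the environment; the sample prompt is illustrative:

```python
import os
from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

response = client.classifiers.moderate_chat(
    model="mistral-moderation-latest",
    inputs=[{"role": "user", "content": "How do I practice box breathing?"}],
)

# Each result maps category names to booleans; True marks a policy hit.
categories = response.results[0].categories
flagged = [name for name, hit in categories.items() if hit]
print(flagged if flagged else "clean")
```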
```diff
@@ -173,7 +173,7 @@ def merge_responses(kb_answer: str, web_answer: str) -> str:
 
 # Orchestrate the entire workflow
 def run_pipeline(query: str) -> str:
-    # Moderate the query for harmful content
+    # Moderate the query for harmful content
     moderated_query = moderate_text(query)
     if moderated_query == "OutOfScope":
         return "Sorry, this query contains harmful or inappropriate content."
```

(The removed and added comment lines here, and again at line 210 below, render identically; presumably these are whitespace or indentation changes that this view does not preserve.)
```diff
@@ -207,7 +207,7 @@ def run_pipeline(query: str) -> str:
     final_refusal = tailor_chain.run({"response": refusal_text})
     return final_refusal.strip()
 
-# Initialize chains
+# Initialize chains
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
 tailor_chain = get_tailor_chain()
```
```diff
@@ -224,3 +224,7 @@ brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
+
+# Function to wrap up and run the chain
+def run_with_chain(query: str) -> str:
+    return run_pipeline(query)
```
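With the chains and vector stores initialized above, the new `run_with_chain` wrapper becomes the single entry point for callers such as a UI handler. A quick smoke test; the query text is illustrative:

```python
# Hypothetical smoke test for the new entry point.
if __name__ == "__main__":
    print(run_with_chain("Can you explain box breathing?"))
```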