Update pipeline.py
pipeline.py CHANGED (+11 -24)
@@ -1,7 +1,5 @@
 import os
 import getpass
-from pydantic_ai import Agent  # Import the Agent from pydantic_ai
-from pydantic_ai.models.mistral import MistralModel  # Import the Mistral model
 import spacy  # Import spaCy for NER functionality
 import pandas as pd
 from typing import Optional
@@ -13,25 +11,12 @@ from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMMod
 import subprocess  # Import subprocess to run shell commands
 from langchain.llms.base import LLM  # Import LLM

-
-from
-from tailor_chain import get_tailor_chain
-from cleaner_chain import get_cleaner_chain, CleanerChain
-
-
-# 1) Environment: set up keys if missing
-if not os.environ.get("GEMINI_API_KEY"):
-    os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
-if not os.environ.get("GROQ_API_KEY"):
-    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
-if not os.environ.get("MISTRAL_API_KEY"):
-    os.environ["MISTRAL_API_KEY"] = getpass.getpass("Enter your Mistral API Key: ")
-
-# Initialize Mistral agent using Pydantic AI
+# Mistral Client Setup
+from mistralai import Mistral  # Import the Mistral client

+# Initialize Mistral API client
 mistral_api_key = os.environ.get("MISTRAL_API_KEY")  # Ensure your Mistral API key is set
-
-mistral_agent = Agent(mistral_model)
+client = Mistral(api_key=mistral_api_key)

 # Load spaCy model for NER and download the spaCy model if not already installed
 def install_spacy_model():
@@ -49,21 +34,21 @@ install_spacy_model()
 # Load the spaCy model globally
 nlp = spacy.load("en_core_web_sm")

-# Function to moderate text using
+# Function to moderate text using Mistral moderation API
 def moderate_text(query: str) -> str:
     """
-    Classifies the query as harmful or not using Mistral Moderation via
+    Classifies the query as harmful or not using Mistral Moderation via Mistral API.
     Returns "OutOfScope" if harmful, otherwise returns the original query.
     """
     # Use the moderation API to evaluate if the query is harmful
-    response =
+    response = client.classifiers.moderate_chat(
         model="mistral-moderation-latest",
         inputs=[
             {"role": "user", "content": query},
         ],
     )
-
-    #
+
+    # Extracting category scores from response
     categories = response['results'][0]['categories']

     # Check if harmful content is flagged in moderation categories
@@ -72,7 +57,9 @@ def moderate_text(query: str) -> str:
       categories.get("dangerous_and_criminal_content", False) or \
       categories.get("selfharm", False):
        return "OutOfScope"
+
    return query
+
 # 3) build_or_load_vectorstore (no changes)
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    if os.path.exists(store_dir):
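
For context, here is a minimal sketch (not part of this commit) of how the client-based moderation path introduced above can be exercised on its own. It assumes MISTRAL_API_KEY is set in the environment and that the installed mistralai SDK exposes classifiers.moderate_chat as used in pipeline.py; the attribute-style access to the results is an assumption that may vary with SDK version.

# Sketch: standalone check of the Mistral moderation call used by moderate_text.
# Assumes MISTRAL_API_KEY is set; the example query string is made up.
import os
from mistralai import Mistral

client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])

response = client.classifiers.moderate_chat(
    model="mistral-moderation-latest",
    inputs=[{"role": "user", "content": "Tell me about your opening hours."}],
)

# Recent mistralai SDK versions return a typed response object, so the flagged
# categories are typically read as attributes rather than dict keys:
categories = response.results[0].categories
flagged = categories.get("dangerous_and_criminal_content", False) or \
          categories.get("selfharm", False)
print("OutOfScope" if flagged else "ok")

This mirrors the moderate_text helper above, which returns "OutOfScope" when either category is flagged and otherwise passes the original query through.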