"""Pass input through a moderation endpoint."""

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain_core.utils import check_package_version, get_from_dict_or_env
from pydantic import Field, model_validator

from langchain.chains.base import Chain


class OpenAIModerationChain(Chain):
    """Pass input through a moderation endpoint.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be
    passed in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.chains import OpenAIModerationChain
            moderation = OpenAIModerationChain()
    """

    client: Any = None
    async_client: Any = None
    model_name: Optional[str] = None
    """Moderation model name to use."""
    error: bool = False
    """Whether or not to error if bad content was found."""
    input_key: str = "input"
    output_key: str = "output"
    openai_api_key: Optional[str] = None
    openai_organization: Optional[str] = None
    openai_pre_1_0: bool = Field(default=False)

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate that the api key and python package exist in the environment."""
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        openai_organization = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        try:
            import openai

            openai.api_key = openai_api_key
            if openai_organization:
                openai.organization = openai_organization
            # Detect whether the installed ``openai`` package predates the 1.0
            # SDK rewrite and select the matching client interface.
            values["openai_pre_1_0"] = False
            try:
                check_package_version("openai", gte_version="1.0")
            except ValueError:
                values["openai_pre_1_0"] = True
            if values["openai_pre_1_0"]:
                values["client"] = openai.Moderation
            else:
                values["client"] = openai.OpenAI(api_key=openai_api_key)
                values["async_client"] = openai.AsyncOpenAI(api_key=openai_api_key)
        except ImportError:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _moderate(self, text: str, results: Any) -> str:
        # Pre-1.0 responses are plain dicts; 1.0+ responses are typed objects
        # accessed via attributes.
        if self.openai_pre_1_0:
            condition = results["flagged"]
        else:
            condition = results.flagged
        if condition:
            error_str = "Text was found that violates OpenAI's content policy."
            if self.error:
                raise ValueError(error_str)
            else:
                return error_str
        return text

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        text = inputs[self.input_key]
        if self.openai_pre_1_0:
            results = self.client.create(text)
            output = self._moderate(text, results["results"][0])
        else:
            results = self.client.moderations.create(input=text)
            output = self._moderate(text, results.results[0])
        return {self.output_key: output}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        if self.openai_pre_1_0:
            # The pre-1.0 SDK has no async client, so defer to the base-class
            # implementation, which wraps the synchronous ``_call``.
            return await super()._acall(inputs, run_manager=run_manager)
        text = inputs[self.input_key]
        results = await self.async_client.moderations.create(input=text)
        output = self._moderate(text, results.results[0])
        return {self.output_key: output}
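

# Minimal usage sketch (illustrative only, not part of the module's API):
# assumes the ``openai`` package is installed and ``OPENAI_API_KEY`` is set in
# the environment, as described in the class docstring above.
if __name__ == "__main__":
    moderation = OpenAIModerationChain()
    # ``invoke`` returns a dict containing the ``output_key`` entry ("output"
    # by default); flagged text is replaced by a policy-violation message, or
    # raises a ``ValueError`` when the chain is constructed with ``error=True``.
    print(moderation.invoke({"input": "some user-supplied text"}))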