| """ | |
| 🤖 Fagun Browser Automation Testing Agent - LLM Provider | |
| ======================================================== | |
| LLM provider utilities and configurations for the Fagun Browser Automation Testing Agent. | |
| Author: Mejbaur Bahar Fagun | |
| Role: Software Engineer in Test | |
| LinkedIn: https://www.linkedin.com/in/mejbaur/ | |
| """ | |
import os
from typing import Any, Optional

from openai import OpenAI

from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_anthropic import ChatAnthropic
from langchain_mistralai import ChatMistralAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langchain_ibm import ChatWatsonx
from langchain_aws import ChatBedrock

from src.utils import config


class DeepSeekR1ChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant that surfaces DeepSeek's `reasoning_content` field."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Keep a raw OpenAI client: the DeepSeek reasoner endpoint returns the
        # chain of thought in `reasoning_content`, which the standard LangChain
        # wrapper does not expose.
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key")
        )

    @staticmethod
    def _to_openai_messages(input: LanguageModelInput) -> list[dict]:
        """Convert a sequence of LangChain messages to OpenAI chat-completions format."""
        message_history = []
        for input_ in input:
            if isinstance(input_, SystemMessage):
                message_history.append({"role": "system", "content": input_.content})
            elif isinstance(input_, AIMessage):
                message_history.append({"role": "assistant", "content": input_.content})
            else:
                message_history.append({"role": "user", "content": input_.content})
        return message_history

    def _complete(self, input: LanguageModelInput) -> AIMessage:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self._to_openai_messages(input)
        )
        reasoning_content = response.choices[0].message.reasoning_content
        content = response.choices[0].message.content
        return AIMessage(content=content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        # The underlying client call is synchronous; this wrapper only
        # satisfies the async interface.
        return self._complete(input)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        return self._complete(input)
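

# A minimal usage sketch (illustrative only; the base_url and api_key values
# below are placeholders, not guaranteed endpoints). Note that invoke expects
# a list of LangChain messages, not a bare string:
#
#     from langchain_core.messages import HumanMessage
#
#     llm = DeepSeekR1ChatOpenAI(
#         model="deepseek-reasoner",
#         base_url="https://api.deepseek.com",
#         api_key="sk-...",
#     )
#     result = llm.invoke([HumanMessage(content="What is 2 + 2?")])
#     print(result.reasoning_content)  # chain-of-thought text
#     print(result.content)            # final answer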


class DeepSeekR1ChatOllama(ChatOllama):
    """ChatOllama variant that splits DeepSeek-R1 `<think>` blocks out of the output."""

    @staticmethod
    def _split_reasoning(org_content: str) -> AIMessage:
        if "</think>" in org_content:
            # Everything before </think> is the model's reasoning trace.
            reasoning_content, content = org_content.split("</think>", 1)
            reasoning_content = reasoning_content.replace("<think>", "")
        else:
            # Guard against completions that omit the think tags entirely.
            reasoning_content, content = "", org_content
        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return AIMessage(content=content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        org_ai_message = await super().ainvoke(input=input)
        return self._split_reasoning(org_ai_message.content)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        org_ai_message = super().invoke(input=input)
        return self._split_reasoning(org_ai_message.content)
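

# Example of the splitting above (illustrative): a raw completion such as
#     "<think>add the numbers</think>**JSON Response:** {\"answer\": 4}"
# yields reasoning_content="add the numbers" and content=' {"answer": 4}'.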


def get_llm_model(provider: str, **kwargs):
    """
    Instantiate a chat model for the given LLM provider.

    :param provider: provider name, e.g. "openai", "anthropic", "ollama"
    :param kwargs: provider options such as model_name, temperature, base_url,
        api_key, num_ctx
    :return: a configured LangChain chat model instance
    """
    if provider not in ["ollama", "bedrock"]:
        # Resolve the API key from kwargs first, then from the provider's
        # conventional environment variable (e.g. OPENAI_API_KEY).
        env_var = f"{provider.upper()}_API_KEY"
        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
        if not api_key:
            provider_display = config.PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
            error_msg = f"💥 {provider_display} API key not found! 🔑 Please set the `{env_var}` environment variable or provide it in the UI."
            raise ValueError(error_msg)
        kwargs["api_key"] = api_key

    if provider == "anthropic":
        base_url = kwargs.get("base_url") or "https://api.anthropic.com"
        return ChatAnthropic(
            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "mistral":
        base_url = kwargs.get("base_url") or os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "openai":
        base_url = kwargs.get("base_url") or os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
| elif provider == "grok": | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| return ChatOpenAI( | |
| model=kwargs.get("model_name", "grok-3"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| base_url=base_url, | |
| api_key=api_key, | |
| ) | |
| elif provider == "deepseek": | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("DEEPSEEK_ENDPOINT", "") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner": | |
| return DeepSeekR1ChatOpenAI( | |
| model=kwargs.get("model_name", "deepseek-reasoner"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| base_url=base_url, | |
| api_key=api_key, | |
| ) | |
| else: | |
| return ChatOpenAI( | |
| model=kwargs.get("model_name", "deepseek-chat"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| base_url=base_url, | |
| api_key=api_key, | |
| ) | |
| elif provider == "google": | |
| return ChatGoogleGenerativeAI( | |
| model=kwargs.get("model_name", "gemini-2.0-flash-exp"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| api_key=api_key, | |
| ) | |
| elif provider == "ollama": | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"): | |
| return DeepSeekR1ChatOllama( | |
| model=kwargs.get("model_name", "deepseek-r1:14b"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| num_ctx=kwargs.get("num_ctx", 32000), | |
| base_url=base_url, | |
| ) | |
| else: | |
| return ChatOllama( | |
| model=kwargs.get("model_name", "qwen2.5:7b"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| num_ctx=kwargs.get("num_ctx", 32000), | |
| num_predict=kwargs.get("num_predict", 1024), | |
| base_url=base_url, | |
| ) | |
| elif provider == "azure_openai": | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview") | |
| return AzureChatOpenAI( | |
| model=kwargs.get("model_name", "gpt-4o"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| api_version=api_version, | |
| azure_endpoint=base_url, | |
| api_key=api_key, | |
| ) | |
| elif provider == "alibaba": | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| return ChatOpenAI( | |
| model=kwargs.get("model_name", "qwen-plus"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| base_url=base_url, | |
| api_key=api_key, | |
| ) | |
| elif provider == "ibm": | |
| parameters = { | |
| "temperature": kwargs.get("temperature", 0.0), | |
| "max_tokens": kwargs.get("num_ctx", 32000) | |
| } | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| return ChatWatsonx( | |
| model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"), | |
| url=base_url, | |
| project_id=os.getenv("IBM_PROJECT_ID"), | |
| apikey=os.getenv("IBM_API_KEY"), | |
| params=parameters | |
| ) | |
| elif provider == "moonshot": | |
| return ChatOpenAI( | |
| model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| base_url=os.getenv("MOONSHOT_ENDPOINT"), | |
| api_key=os.getenv("MOONSHOT_API_KEY"), | |
| ) | |
| elif provider == "unbound": | |
| return ChatOpenAI( | |
| model=kwargs.get("model_name", "gpt-4o-mini"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"), | |
| api_key=api_key, | |
| ) | |
| elif provider == "siliconflow": | |
| if not kwargs.get("api_key", ""): | |
| api_key = os.getenv("SiliconFLOW_API_KEY", "") | |
| else: | |
| api_key = kwargs.get("api_key") | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("SiliconFLOW_ENDPOINT", "") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| return ChatOpenAI( | |
| api_key=api_key, | |
| base_url=base_url, | |
| model_name=kwargs.get("model_name", "Qwen/QwQ-32B"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| ) | |
| elif provider == "modelscope": | |
| if not kwargs.get("api_key", ""): | |
| api_key = os.getenv("MODELSCOPE_API_KEY", "") | |
| else: | |
| api_key = kwargs.get("api_key") | |
| if not kwargs.get("base_url", ""): | |
| base_url = os.getenv("MODELSCOPE_ENDPOINT", "") | |
| else: | |
| base_url = kwargs.get("base_url") | |
| return ChatOpenAI( | |
| api_key=api_key, | |
| base_url=base_url, | |
| model_name=kwargs.get("model_name", "Qwen/QwQ-32B"), | |
| temperature=kwargs.get("temperature", 0.0), | |
| extra_body = {"enable_thinking": False} | |
| ) | |
| else: | |
| raise ValueError(f"Unsupported provider: {provider}") | |
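

# A minimal usage sketch (illustrative; assumes OPENAI_API_KEY is set in the
# environment and that src.utils.config is importable):
#
#     llm = get_llm_model(
#         "openai",
#         model_name="gpt-4o",
#         temperature=0.0,
#     )
#     print(llm.invoke("Hello").content)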