Spaces:
No application file
No application file
| import os | |
| import json | |
| import asyncio | |
| from typing import List, Optional, Dict, Any | |
| from loguru import logger | |
| from pydantic import BaseModel | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
| from langchain_openai import ChatOpenAI | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_groq import ChatGroq | |
| from langchain_ollama import ChatOllama | |
| from config import MODEL_CONFIG, INITIAL_MESSAGES_CONFIG, MODE_CONFIG, NAVIGATION_FUNCTIONS, ROBOT_SPECIFIC_FUNCTIONS, ROBOT_NAMES | |
class LLMRequestConfig(BaseModel):
    """Serializable settings used to build an ``LLMRequestHandler``.

    Field defaults are read from ``MODEL_CONFIG`` when this class is defined.
    """

    model_name: str = MODEL_CONFIG["default_model"]
    max_tokens: int = MODEL_CONFIG["max_tokens"]
    temperature: float = MODEL_CONFIG["temperature"]
    frequency_penalty: float = MODEL_CONFIG["frequency_penalty"]
    list_navigation_once: bool = True
    provider: str = "openai"

    # Disable Pydantic's protected "model_" namespace so the ``model_name``
    # field does not trigger a namespace-conflict warning.
    model_config = {"protected_namespaces": ()}

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain dict (inverse of ``from_dict``)."""
        return {
            "model_name": self.model_name,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "frequency_penalty": self.frequency_penalty,
            "list_navigation_once": self.list_navigation_once,
            "provider": self.provider,
        }

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "LLMRequestConfig":
        """Build a config from a dict such as the one returned by ``to_dict``.

        Must be a classmethod: it is invoked on the class itself
        (``LLMRequestConfig.from_dict(d)``). Keys absent from ``config_dict``
        keep their field defaults; values are validated by Pydantic.
        """
        return cls(**config_dict)
class LLMRequestHandler:
    """Build DART-style prompts and send them to a LangChain chat model.

    The concrete chat model (Anthropic, Ollama, Groq or OpenAI) is created in
    ``_setup_llm`` from the configured ``provider`` and ``model_name``.
    """

    class Message(BaseModel):
        """One chat turn in ``{"role": ..., "content": ...}`` form."""

        role: str
        content: str

    # Maps the plain-dict role names used in message histories to LangChain
    # message classes.
    _ROLE_TO_MESSAGE = {
        "system": SystemMessage,
        "user": HumanMessage,
        "assistant": AIMessage,
    }

    def __init__(self,
                 # Support both old and new parameter names for backward compatibility
                 model_version: Optional[str] = None,
                 model_name: Optional[str] = None,
                 max_tokens: Optional[int] = None,
                 temperature: Optional[float] = None,
                 frequency_penalty: Optional[float] = None,
                 list_navigation_once: Optional[bool] = None,
                 model_type: Optional[str] = None,
                 provider: Optional[str] = None,
                 config: Optional[LLMRequestConfig] = None):
        """Create a handler from a ready-made ``config`` or individual settings.

        Args:
            model_version: Legacy alias for ``model_name``.
            model_name: Model identifier passed to the provider client.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature (0.0 is honoured).
            frequency_penalty: Frequency penalty (0.0 is honoured).
            list_navigation_once: Stored on the config; defaults to True.
            model_type: Legacy alias for ``provider``.
            provider: Provider name, e.g. "openai", "anthropic", "groq", "ollama".
            config: When given, all individual settings above are ignored.
        """
        if config is not None:
            self.config = config
        else:
            # New parameter names win over their legacy aliases. Numeric
            # settings are compared against None rather than tested for
            # truthiness, so an explicit 0 / 0.0 (e.g. temperature=0.0 for
            # deterministic output) is kept instead of replaced by the default.
            self.config = LLMRequestConfig(
                model_name=model_name or model_version or MODEL_CONFIG["default_model"],
                max_tokens=max_tokens if max_tokens is not None else MODEL_CONFIG["max_tokens"],
                temperature=temperature if temperature is not None else MODEL_CONFIG["temperature"],
                frequency_penalty=(
                    frequency_penalty if frequency_penalty is not None
                    else MODEL_CONFIG["frequency_penalty"]
                ),
                list_navigation_once=list_navigation_once if list_navigation_once is not None else True,
                provider=provider or model_type or "openai",
            )

        # Mirror config fields as attributes for easier access.
        self.model_name = self.config.model_name
        self.model_version = self.model_name  # Alias for backward compatibility
        self.max_tokens = self.config.max_tokens
        self.temperature = self.config.temperature
        self.frequency_penalty = self.config.frequency_penalty
        self.list_navigation_once = self.config.list_navigation_once
        self.provider = self.config.provider
        self.model_type = self.provider  # Alias for backward compatibility

        # API keys come from the environment; None when unset.
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
        self.groq_api_key = os.getenv("GROQ_API_KEY")

        self._setup_llm()

    def _setup_llm(self):
        """Create ``self.llm`` for the configured provider.

        Each branch matches either a substring of ``provider`` or a hint in
        ``model_name``; the first match wins, checked in the order Anthropic,
        Ollama, Groq, and finally OpenAI as the default.
        """
        if "anthropic" in self.provider or "claude" in self.model_name:
            self.llm = ChatAnthropic(
                api_key=self.anthropic_api_key,
                model_name=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
        elif "ollama" in self.provider or "ollama" in self.model_name:
            self.llm = ChatOllama(
                model=self.model_name,
                # ChatOllama exposes the output-length limit as ``num_predict``;
                # ``max_tokens`` is not one of its documented fields.
                num_predict=self.max_tokens,
                temperature=self.temperature,
                # Default targets the Docker host; override with OLLAMA_BASE_URL
                # when Ollama runs elsewhere.
                base_url=os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434"),
            )
        elif "groq" in self.provider or "llama" in self.model_name:
            self.llm = ChatGroq(
                api_key=self.groq_api_key,
                model_name=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                frequency_penalty=self.frequency_penalty,
            )
        else:  # Default to OpenAI
            self.llm = ChatOpenAI(
                api_key=self.openai_api_key,
                model_name=self.model_name,
                max_tokens=self.max_tokens,
                temperature=self.temperature,
                frequency_penalty=self.frequency_penalty,
            )

    def get_config_dict(self) -> Dict[str, Any]:
        """Get a serializable configuration dictionary."""
        return self.config.to_dict()

    @staticmethod
    def create_from_config_dict(config_dict: Dict[str, Any]) -> "LLMRequestHandler":
        """Create a new handler instance from a config dictionary.

        Works when called on the class or on an instance. Keys missing from
        ``config_dict`` fall back to the ``LLMRequestConfig`` field defaults.
        """
        return LLMRequestHandler(config=LLMRequestConfig(**config_dict))

    def load_object_data(self, json_path: Optional[str] = None) -> Dict[str, Any]:
        """Load environment information (E) from the object-database JSON file.

        Args:
            json_path: Optional path to the JSON file. When omitted, uses
                ``../ros2_ws/src/breakdown_function_handler/object_database/object_database.json``
                relative to this module's directory.

        Returns:
            The mapping produced by ``format_env_object``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        if json_path is None:
            json_path = os.path.abspath(os.path.join(
                os.path.dirname(__file__), '..', 'ros2_ws', 'src',
                'breakdown_function_handler', 'object_database', 'object_database.json'))
        with open(json_path, 'r', encoding='utf-8') as json_file:
            data = json.load(json_file)
        return self.format_env_object(data)

    def format_env_object(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Format the environment data (E) for use in the prompt.

        Each entry must provide ``object_name``, ``target_position`` (with
        ``x`` and ``y``) and ``shape``; a missing key raises ``KeyError``.
        If two entries share an ``object_name``, the later one wins.

        Returns:
            ``{object_name: {"position": {"x": ..., "y": ...}, "shape": ...}}``
        """
        return {
            obj['object_name']: {
                "position": {
                    "x": obj['target_position']["x"],
                    "y": obj['target_position']["y"],
                },
                "shape": obj['shape'],
            }
            for obj in data
        }

    def build_initial_messages(self, file_path: str, mode: str) -> List[Dict[str, str]]:
        """Build the initial prompt (P = (I, E, R, S)) as a three-message history.

        Args:
            file_path: Text file whose contents are inserted as few-shot examples.
            mode: Key into ``MODE_CONFIG`` and ``INITIAL_MESSAGES_CONFIG["user_intro"]``.

        Returns:
            ``[system, user, assistant]`` message dicts.

        Raises:
            KeyError: If ``mode`` is not a key of ``MODE_CONFIG``.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            fewshot_examples = file.read()

        system = INITIAL_MESSAGES_CONFIG["system"]
        env_objects = self.load_object_data()

        # Shared intro text plus an optional mode-specific suffix.
        user_intro = (INITIAL_MESSAGES_CONFIG["user_intro"]["default"]
                      + INITIAL_MESSAGES_CONFIG["user_intro"].get(mode, ""))
        functions_description = MODE_CONFIG[mode].get("functions_description", "")

        # Fill in the robot set (R), skills (S), environment (E) and examples.
        user_intro = user_intro.format(
            library=NAVIGATION_FUNCTIONS + ROBOT_SPECIFIC_FUNCTIONS,
            env_objects=env_objects,
            robot_names=ROBOT_NAMES,
            fewshot_examples=fewshot_examples,
            functions_description=functions_description,
        )
        assistant1 = INITIAL_MESSAGES_CONFIG["assistant"]

        return [
            {"role": "system", "content": system},
            {"role": "user", "content": user_intro},
            {"role": "assistant", "content": assistant1},
        ]

    def add_user_message(self, messages: List[Dict[str, str]], content: str) -> None:
        """Append a user message carrying a natural-language instruction (I), in place."""
        user_message = self.Message(role="user", content=content)
        messages.append(user_message.model_dump())

    def _convert_to_langchain_messages(self, full_history: List[Dict[str, str]]):
        """Convert ``{"role", "content"}`` dicts to LangChain message objects.

        Roles other than system/user/assistant are skipped with a warning.
        """
        lc_messages = []
        for msg in full_history:
            message_cls = self._ROLE_TO_MESSAGE.get(msg["role"])
            if message_cls is None:
                # Skip, but make the dropped turn visible instead of silent.
                logger.warning(f"Skipping message with unsupported role: {msg['role']!r}")
                continue
            lc_messages.append(message_cls(content=msg["content"]))
        return lc_messages

    async def make_completion(self, full_history: List[Dict[str, str]]) -> Optional[str]:
        """Send ``full_history`` to the configured model and return the reply text.

        Returns:
            The reply content, or None if the request failed. Failures are
            logged with their traceback rather than raised.
        """
        logger.debug(f"Using model: {self.model_name}")
        try:
            lc_messages = self._convert_to_langchain_messages(full_history)
            # Chat models accept a list of messages directly; the history is
            # already fully rendered, so no prompt template is needed.
            response = await self.llm.ainvoke(lc_messages)
        except Exception as e:
            # Best-effort API: callers receive None. logger.exception keeps the
            # stack trace that logger.error would drop.
            logger.exception(f"Error making completion: {e}")
            return None
        return response.content if hasattr(response, 'content') else str(response)
if __name__ == "__main__":
    async def main():
        """Smoke-test the handler on the DART swarm prompt with one chosen model."""
        selected_model_index = 3  # 0 for OpenAI, 1 for Anthropic, 2 for LLaMA (Groq), 3 for Ollama
        prompt_path = "/root/share/QA_LLM_Module/prompts/swarm/dart.txt"
        model_options = MODEL_CONFIG["model_options"]

        # A single dispatch picks model, provider and prompt mode together, so
        # they cannot drift apart (previously two parallel if/elif chains).
        if selected_model_index == 0:
            model, provider, mode = model_options[0], "openai", "dart_gpt_4o"
        elif selected_model_index == 1:
            model, provider, mode = model_options[4], "anthropic", "dart_claude_3_sonnet"
        elif selected_model_index == 2:
            model, provider, mode = model_options[6], "groq", "dart_llama_3_3_70b"
        elif selected_model_index == 3:
            # NOTE(review): the mode name refers to llama3.1 8B but the model
            # requested from Ollama is "llama3" — confirm which is intended.
            model, provider, mode = "llama3", "ollama", "dart_ollama_llama3_1_8b"
        else:
            raise ValueError("Invalid selected_model_index")

        logger.debug("Starting test llm_request_handler with LangChain...")
        config = LLMRequestConfig(
            model_name=model,
            list_navigation_once=True,
            provider=provider,
        )
        handler = LLMRequestHandler(config=config)

        messages = handler.build_initial_messages(prompt_path, mode)
        # Add a natural language instruction (I) to the prompt.
        handler.add_user_message(
            messages,
            "Excavator 1 performs excavation, then excavator 2 performs, then dump 1 performs unload.",
        )

        response = await handler.make_completion(messages)
        if response is None:
            logger.error("make_completion returned no response; see the error logged above")
        logger.debug(f"Response from make_completion: {response}")

    asyncio.run(main())