from typing import Any, Dict, List, Optional
import logging

import torch
from crewai import Agent, Task
from pydantic import ConfigDict, Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from utils.log_manager import LogManager
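
# Illustrative only (an assumption, not part of the original module): the
# shape of the config dict this class reads, inferred from the keys used
# below (model_id, instruction_template, max_length, temperature, top_p,
# repetition_penalty), keyed by agent_type. The model id and instruction
# template are the usual Mistral instruct-format values, not confirmed by
# the source.
EXAMPLE_MODEL_CONFIG = {
    "base": {
        "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
        "instruction_template": "<s>[INST] {input} [/INST]",
        "max_length": 4096,
        "temperature": 0.7,
        "top_p": 0.95,
        "repetition_penalty": 1.1,
    },
}
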
class BaseWellnessAgent(Agent):
    """Base agent class with Mistral LLM support."""

    # Allow arbitrary types (model, tokenizer, logger) on the pydantic model
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Fields used by this agent
    log_manager: LogManager = Field(default_factory=LogManager)
    logger: Optional[logging.Logger] = Field(default=None)
    config: Dict = Field(default_factory=dict)
    model: Any = Field(default=None)
    tokenizer: Any = Field(default=None)
    agent_type: str = Field(default="base")
    def __init__(self, model_config: Dict, agent_type: str, **kwargs):
        # Initialize the CrewAI agent first with its required fields.
        # Pop the explicitly-set keys so they are not passed twice through
        # **kwargs (kwargs.get left them in place, which raises "multiple
        # values for keyword argument" whenever a caller supplies e.g. role).
        super().__init__(
            role=kwargs.pop("role", "Wellness Support Agent"),
            goal=kwargs.pop("goal", "Support mental wellness"),
            backstory=kwargs.pop(
                "backstory",
                "I am an AI agent specialized in mental health support.",
            ),
            verbose=kwargs.pop("verbose", True),
            allow_delegation=kwargs.pop("allow_delegation", False),
            tools=kwargs.pop("tools", []),
            **kwargs,
        )
        # Store configuration and set up a per-agent logger
        self.config = model_config
        self.agent_type = agent_type
        self.logger = self.log_manager.get_agent_logger(agent_type)
        # Load the Mistral model and tokenizer
        self._initialize_model()
        self.logger.info(f"{agent_type.capitalize()} Agent initialized")
    def _initialize_model(self):
        """Initialize the Mistral model and tokenizer for this agent type."""
        try:
            agent_config = self.config[self.agent_type]
            self.logger.info(f"Initializing Mistral model: {agent_config['model_id']}")
            # Load the tokenizer and a 4-bit quantized model. BitsAndBytesConfig
            # replaces the deprecated load_in_4bit= kwarg (and the original
            # torch_dtype=torch.float32, which conflicts with 4-bit loading);
            # it requires the bitsandbytes package and a CUDA device.
            self.tokenizer = AutoTokenizer.from_pretrained(agent_config["model_id"])
            self.model = AutoModelForCausalLM.from_pretrained(
                agent_config["model_id"],
                device_map="auto",
                quantization_config=BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                ),
            )
            self.logger.info("Mistral model initialized successfully")
        except Exception as e:
            self.logger.error(f"Error initializing Mistral model: {e}")
            raise
    def _generate_response(self, input_text: str) -> str:
        """Generate a response using the Mistral model."""
        try:
            agent_config = self.config[self.agent_type]
            # Wrap the input in the agent's instruction template
            prompt = agent_config["instruction_template"].format(input=input_text)
            # Tokenize and move tensors to the model's device
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # Generate; note max_length counts prompt tokens as well as new ones
            outputs = self.model.generate(
                **inputs,
                max_length=agent_config.get("max_length", 4096),
                temperature=agent_config.get("temperature", 0.7),
                top_p=agent_config.get("top_p", 0.95),
                repetition_penalty=agent_config.get("repetition_penalty", 1.1),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens. Decoding the full
            # sequence and string-replacing the prompt is fragile once
            # skip_special_tokens alters the prompt text.
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            self.logger.error(f"Error generating response: {e}")
            return "I apologize, but I encountered an error generating a response."
    def execute_task(self, task: Task, context: Optional[str] = None, tools: Optional[List] = None) -> str:
        """Execute a task assigned to the agent."""
        self.logger.info(f"Executing task: {task.description}")
        try:
            # Accept context/tools so the override matches the signature
            # CrewAI calls with; this agent only uses the task description.
            result = self.process_message(task.description)
            return result["message"]
        except Exception as e:
            self.logger.error(f"Error executing task: {e}")
            return "I apologize, but I encountered an error processing your request."
    def process_message(self, message: str, context: Optional[Dict] = None) -> Dict:
        """Process a message and return a response payload."""
        self.logger.info("Processing message")
        context = context or {}
        try:
            # Generate a response with the Mistral model
            response = self._generate_response(message)
            return {
                "message": response,
                "agent_type": self.agent_type,
                "task_type": "dialogue",
            }
        except Exception as e:
            self.logger.error(f"Error processing message: {e}")
            return {
                "message": "I apologize, but I encountered an error. Let me try a different approach.",
                "agent_type": self.agent_type,
                "task_type": "error_recovery",
            }
    def get_status(self) -> Dict:
        """Get the current status of the agent."""
        return {
            "type": self.agent_type,
            "ready": bool(self.model and self.tokenizer),
            "tools_available": len(self.tools),
        }
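
# Minimal usage sketch (an assumption, not from the original file): build an
# agent from a config shaped like EXAMPLE_MODEL_CONFIG above and send it a
# message. Running this downloads the model and needs a CUDA device for
# 4-bit loading.
if __name__ == "__main__":
    agent = BaseWellnessAgent(
        model_config=EXAMPLE_MODEL_CONFIG,
        agent_type="base",
    )
    print(agent.get_status())
    reply = agent.process_message("I've been feeling anxious lately.")
    print(reply["message"])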