# Provenance (Hugging Face Space header, kept as comments so the file parses):
# author: 2stacks
# commit: "Add environment variable configuration and enhanced Ollama model verification"
# revision: 5c62b72 (verified)
from smolagents import CodeAgent, DuckDuckGoSearchTool, FinalAnswerTool, InferenceClientModel, LiteLLMModel, tool
import os
import requests
import pytz
import yaml
from datetime import datetime
from Gradio_UI import GradioUI
from dotenv import load_dotenv
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()
# Configuration — each value can be overridden via the environment / .env file.
# Hugging Face model used when the local Ollama service is not available.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "Qwen/Qwen2.5-Coder-32B-Instruct")
# Base URL of the local Ollama HTTP API.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
# Ollama model tag that must already be pulled locally for the Ollama path to be used.
OLLAMA_MODEL_ID = os.getenv("OLLAMA_MODEL_ID", "qwen2.5-coder:32b")
def is_ollama_available(base_url=None, timeout=2, model_id=None):
    """Check whether the Ollama service is reachable and hosts the requested model.

    Args:
        base_url: Base URL of the Ollama HTTP API. Defaults to OLLAMA_BASE_URL.
        timeout: Seconds to wait for the HTTP probe before giving up.
        model_id: Ollama model name to look for. Defaults to OLLAMA_MODEL_ID.
            (New optional parameter; existing callers are unaffected.)

    Returns:
        True if the service responded with HTTP 200 and the model is listed,
        False otherwise. Never raises — all failures are reported via print.
    """
    if base_url is None:
        base_url = OLLAMA_BASE_URL
    if model_id is None:
        model_id = OLLAMA_MODEL_ID
    try:
        # /api/tags lists every model pulled into the local Ollama instance.
        response = requests.get(f"{base_url}/api/tags", timeout=timeout)
        if response.status_code != 200:
            print(f"Ollama service check failed: HTTP {response.status_code} from {base_url}/api/tags")
            return False
        # Parse the response to get available models.
        data = response.json()
        available_models = [model.get('name', '') for model in data.get('models', [])]
        # Check if the requested model exists among the available models.
        if model_id not in available_models:
            print(f"Model '{model_id}' not found in Ollama.")
            print(f"Available models: {', '.join(available_models) if available_models else 'None'}")
            return False
        print(f"Ollama service is available and model '{model_id}' found.")
        return True
    except (requests.RequestException, ConnectionError) as e:
        # Service down, DNS failure, timeout, refused connection, etc.
        print(f"Failed to connect to Ollama service at {base_url}: {type(e).__name__}: {e}")
        return False
    except (ValueError, KeyError) as e:
        # Malformed JSON body (json.JSONDecodeError is a ValueError subclass).
        print(f"Failed to parse Ollama API response: {type(e).__name__}: {e}")
        return False
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.
    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    # Resolve the zone and read the clock; any failure (unknown zone name,
    # bad input type) is reported back to the agent as a plain message
    # rather than an exception.
    try:
        zone = pytz.timezone(timezone)
        stamp = datetime.now(zone).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        return f"Error fetching time for timezone '{timezone}': {str(e)}"
    return f"The current local time in {timezone} is: {stamp}"
@tool
def write_to_markdown(content: str, filename: str) -> None:
    """
    Write the given content to a markdown file.
    Args:
        content: The text content to write to the file
        filename: The name of the markdown file (will automatically add .md extension if not present)
    Example:
        >>> write_to_markdown("# Hello World\\nThis is a test.", "output.md")
    """
    # Add .md extension if not present
    if not filename.endswith('.md'):
        filename += '.md'
    # Write content to file (UTF-8, overwriting any existing file)
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(content)
    # Bug fix: the confirmation message previously printed the literal
    # placeholder "(unknown)" instead of the actual destination filename.
    print(f"Content successfully written to {filename}")
# --- Tool instantiation ---
# Web search tool: capped at 5 results per query and rate-limited to one
# request every 2 seconds to avoid DuckDuckGo throttling.
search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0)
# FinalAnswerTool lets the agent terminate a run with its answer.
final_answer = FinalAnswerTool()
# --- Model selection ---
# Prefer a local Ollama instance when it is reachable and hosts the configured
# model; otherwise fall back to the Hugging Face Inference API.
if is_ollama_available():
    print("Ollama detected - using LiteLLMModel with local Ollama instance")
    model = LiteLLMModel(
        model_id=f"ollama_chat/{OLLAMA_MODEL_ID}", # LiteLLM route prefix for local Ollama chat models
        api_base=OLLAMA_BASE_URL,
        api_key="ollama", # Ollama ignores the key, but LiteLLM requires one to be set
        num_ctx=8192, # Important: Ollama's default 2048-token context may cause failures
        max_tokens=2096, # NOTE(review): 2096 looks like a typo for 2048 or 4096 — confirm
        temperature=0.5,
    )
else:
    print("Ollama not available - falling back to InferenceClientModel")
    # If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
    # model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
    model = InferenceClientModel(
        max_tokens=2096, # mirrors the completion budget used on the Ollama branch
        temperature=0.5,
        model_id=HF_MODEL_ID, # it is possible that this model may be overloaded
        custom_role_conversions=None,
    )
# Load the agent's prompt templates from the sibling prompts.yaml file.
# safe_load is used so arbitrary YAML tags cannot execute code.
with open("prompts.yaml", 'r') as stream:
    prompt_templates = yaml.safe_load(stream)
# Assemble the code-writing agent with the selected model and tool set.
agent = CodeAgent(
    model=model,
    tools=[
        final_answer,
        search_tool,
        get_current_time_in_timezone,
        write_to_markdown,
    ], ## add your tools here (don't remove final answer)
    max_steps=6, # hard cap on reasoning/action iterations per task
    verbosity_level=1,
    planning_interval=None, # no periodic re-planning step
    name=None,
    description=None,
    prompt_templates=prompt_templates
)
# Launch the Gradio web interface for interacting with the agent.
GradioUI(agent).launch()