"""
Ollama Integration - Local model serving and HuggingFace mapping

Provides integration with Ollama for local model inference and maps
Ollama models to their HuggingFace equivalents for training.
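
Example usage (a minimal sketch; assumes an Ollama server is reachable
at the default http://localhost:11434 and that the requested model has
already been pulled):

    client = OllamaClient()
    if client.is_available():
        result = client.generate("qwen2.5:7b", "Say hello.")
        print(result.get("response", ""))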
"""

import requests
from typing import Optional, Dict, List


class OllamaClient:
    """Client for interacting with Ollama API"""

    def __init__(self, base_url: str = "http://localhost:11434"):
        # Normalize so the f-strings below never produce double slashes.
        self.base_url = base_url.rstrip("/")

    def generate(self, model: str, prompt: str, **kwargs) -> Dict:
        """Generate text with an Ollama model via /api/generate.

        On success the generated text is under the ``response`` key of
        the returned dict; on failure the dict carries an ``error`` key.
        """
        url = f"{self.base_url}/api/generate"

        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            **kwargs
        }

        try:
            response = requests.post(url, json=payload, timeout=120)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            return {"error": str(e), "response": ""}

    def chat(self, model: str, messages: List[Dict], **kwargs) -> Dict:
        """Chat with an Ollama model via /api/chat.

        ``messages`` follows the Ollama chat format, e.g.
        ``[{"role": "user", "content": "Hello"}]``.
        """
        url = f"{self.base_url}/api/chat"

        payload = {
            "model": model,
            "messages": messages,
            "stream": False,
            **kwargs
        }

        try:
            response = requests.post(url, json=payload, timeout=120)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            return {"error": str(e), "message": {"content": ""}}

    def list_models(self) -> List[str]:
        """List available Ollama models"""
        url = f"{self.base_url}/api/tags"

        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            data = response.json()
            return [model["name"] for model in data.get("models", [])]
        except requests.exceptions.RequestException:
            return []

    def is_available(self) -> bool:
        """Check if Ollama is running"""
        try:
            response = requests.get(self.base_url, timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False

    def model_exists(self, model: str) -> bool:
        """Check if a specific model is available locally."""
        models = self.list_models()
        # Ollama lists untagged pulls as "<name>:latest", so accept both forms.
        return model in models or f"{model}:latest" in models


# Mapping of Ollama models to HuggingFace equivalents
OLLAMA_TO_HF_MAP = {
    "qwen2.5:7b": "Qwen/Qwen2.5-7B-Instruct",
    "qwen2.5:32b": "Qwen/Qwen2.5-32B-Instruct",
    "llama3.1:8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "llama3.1:70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "mistral:7b": "mistralai/Mistral-7B-Instruct-v0.3",
    "mixtral:8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "phi3:mini": "microsoft/Phi-3-mini-4k-instruct",
    "gemma:7b": "google/gemma-7b-it",
}


def get_hf_model_for_ollama(ollama_model: str) -> Optional[str]:
    """Get the HuggingFace model ID for an Ollama model"""
    # Strip variant suffixes from the tag (e.g., "qwen2.5:7b-instruct" ->
    # "qwen2.5:7b"). Split only within the tag portion so hyphenated model
    # names (e.g., "nous-hermes2:latest") are not mangled.
    if ":" in ollama_model:
        name, tag = ollama_model.split(":", 1)
        base_model = f"{name}:{tag.split('-')[0]}"
    else:
        base_model = ollama_model

    return OLLAMA_TO_HF_MAP.get(base_model)


def get_ollama_for_hf(hf_model: str) -> Optional[str]:
    """Get the Ollama model for a HuggingFace model ID"""
    reverse_map = {v: k for k, v in OLLAMA_TO_HF_MAP.items()}
    return reverse_map.get(hf_model)
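
# Quick sanity check of the two mapping helpers (doctest-style; the expected
# values follow directly from OLLAMA_TO_HF_MAP above):
#
#     >>> get_hf_model_for_ollama("qwen2.5:7b-instruct")
#     'Qwen/Qwen2.5-7B-Instruct'
#     >>> get_ollama_for_hf("mistralai/Mistral-7B-Instruct-v0.3")
#     'mistral:7b'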


def test_financial_advisor_ollama(
    model: str = "qwen2.5:7b",
    ollama_client: Optional[OllamaClient] = None
) -> str:
    """Test a financial advisor model via Ollama"""

    if ollama_client is None:
        ollama_client = OllamaClient()

    if not ollama_client.is_available():
        return "Error: Ollama is not running. Please start Ollama first."

    if not ollama_client.model_exists(model):
        return f"Error: Model '{model}' is not available. Please pull it first with: ollama pull {model}"

    # Test prompt
    prompt = """You are a financial advisor. A client asks:

"I'm 35 years old with $50,000 in savings. Should I invest in stocks or bonds?"

Provide professional financial advice."""

    result = ollama_client.generate(model, prompt)

    if "error" in result:
        return f"Error: {result['error']}"

    return result.get("response", "No response")


# Global client instance (created lazily by get_ollama_client)
_client: Optional[OllamaClient] = None


def get_ollama_client() -> OllamaClient:
    """Get the global Ollama client instance"""
    global _client
    if _client is None:
        _client = OllamaClient()
    return _client
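

# Minimal smoke test when run as a script. This is a sketch, not part of the
# public API: it assumes an Ollama server on localhost:11434 and uses the
# module's default "qwen2.5:7b" model; substitute any model you have pulled.
if __name__ == "__main__":
    client = get_ollama_client()
    if not client.is_available():
        print("Ollama is not running; start it with: ollama serve")
    else:
        print("Available models:", ", ".join(client.list_models()) or "(none)")
        print(test_financial_advisor_ollama())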