from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


class OllamaClient(BaseLLMWrapper):
    """
    Requires a local or remote Ollama server to be running (default port 11434).
    The top_logprobs field is not yet implemented by the official API.
    """

    def __init__(
        self,
        *,
        model: str = "gemma3",
        base_url: str = "http://localhost:11434",
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        **kwargs: Any,
    ):
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "Ollama SDK is not installed."
                "It is required to use OllamaClient."
                "Please install it with `pip install ollama`."
            ) from e
        super().__init__(**kwargs)
        self.model_name = model
        self.base_url = base_url
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()
        self.token_usage: List[Dict[str, int]] = []

        self.client = ollama.AsyncClient(host=self.base_url)

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> str:
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": text})

        options = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_predict": self.max_tokens,
        }
        if self.seed is not None:
            options["seed"] = self.seed

        # Rough token budget for throttling: prompt tokens plus the configured
        # completion ceiling (num_predict).
        prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages)
        est = prompt_tokens + self.max_tokens
        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(est, silent=True)

        response = await self.client.chat(
            model=self.model_name,
            messages=messages,
            format="json" if self.json_mode else "",
            options=options,
            stream=False,
        )

        # Ollama reports usage as prompt_eval_count / eval_count.
        prompt_used = response.get("prompt_eval_count", 0)
        completion_used = response.get("eval_count", 0)
        self.token_usage.append(
            {
                "prompt_tokens": prompt_used,
                "completion_tokens": completion_used,
                "total_tokens": prompt_used + completion_used,
            }
        )
        content = response["message"]["content"]
        return self.filter_think_tags(content)

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token top-k yet.")

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token logprobs yet.")