GraphGen / graphgen /models /llm /api /ollama_client.py
github-actions[bot]
Auto-sync from demo at Wed Oct 29 11:25:28 UTC 2025
d02622b
raw
history blame
3.48 kB
from typing import Any, Dict, List, Optional
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM
class OllamaClient(BaseLLMWrapper):
"""
Requires a local or remote Ollama server to be running (default port 11434).
The top_logprobs field is not yet implemented by the official API.
"""
def __init__(
self,
*,
model: str = "gemma3",
base_url: str = "http://localhost:11434",
json_mode: bool = False,
seed: Optional[int] = None,
topk_per_token: int = 5,
request_limit: bool = False,
rpm: Optional[RPM] = None,
tpm: Optional[TPM] = None,
**kwargs: Any,
):
try:
import ollama
except ImportError as e:
raise ImportError(
"Ollama SDK is not installed."
"It is required to use OllamaClient."
"Please install it with `pip install ollama`."
) from e
super().__init__(**kwargs)
self.model_name = model
self.base_url = base_url
self.json_mode = json_mode
self.seed = seed
self.topk_per_token = topk_per_token
self.request_limit = request_limit
self.rpm = rpm or RPM()
self.tpm = tpm or TPM()
self.token_usage: List[Dict[str, int]] = []
self.client = ollama.AsyncClient(host=self.base_url)
async def generate_answer(
self,
text: str,
history: Optional[List[Dict[str, str]]] = None,
**extra: Any,
) -> str:
messages = []
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
if history:
messages.extend(history)
messages.append({"role": "user", "content": text})
options = {
"temperature": self.temperature,
"top_p": self.top_p,
"num_predict": self.max_tokens,
}
if self.seed is not None:
options["seed"] = self.seed
prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages)
est = prompt_tokens + self.max_tokens
if self.request_limit:
await self.rpm.wait(silent=True)
await self.tpm.wait(est, silent=True)
response = await self.client.chat(
model=self.model_name,
messages=messages,
format="json" if self.json_mode else "",
options=options,
stream=False,
)
usage = response.get("prompt_eval_count", 0), response.get("eval_count", 0)
self.token_usage.append(
{
"prompt_tokens": usage[0],
"completion_tokens": usage[1],
"total_tokens": sum(usage),
}
)
content = response["message"]["content"]
return self.filter_think_tags(content)
async def generate_topk_per_token(
self,
text: str,
history: Optional[List[Dict[str, str]]] = None,
**extra: Any,
) -> List[Token]:
raise NotImplementedError("Ollama API does not support per-token top-k yet.")
async def generate_inputs_prob(
self, text: str, history: Optional[List[Dict[str, str]]] = None, **extra: Any
) -> List[Token]:
raise NotImplementedError("Ollama API does not support per-token logprobs yet.")