from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token
from graphgen.models.llm.limitter import RPM, TPM


class OllamaClient(BaseLLMWrapper):
    """
    Requires a local or remote Ollama server to be running (default port 11434).
    The top_logprobs field is not yet implemented by the official API.
    """

    def __init__(
        self,
        *,
        model: str = "gemma3",
        base_url: str = "http://localhost:11434",
        json_mode: bool = False,
        seed: Optional[int] = None,
        topk_per_token: int = 5,
        request_limit: bool = False,
        rpm: Optional[RPM] = None,
        tpm: Optional[TPM] = None,
        **kwargs: Any,
    ):
        try:
            import ollama
        except ImportError as e:
            raise ImportError(
                "Ollama SDK is not installed. "
                "It is required to use OllamaClient. "
                "Please install it with `pip install ollama`."
            ) from e

        super().__init__(**kwargs)
        self.model_name = model
        self.base_url = base_url
        self.json_mode = json_mode
        self.seed = seed
        self.topk_per_token = topk_per_token
        self.request_limit = request_limit
        self.rpm = rpm or RPM()
        self.tpm = tpm or TPM()
        self.token_usage: List[Dict[str, int]] = []

        self.client = ollama.AsyncClient(host=self.base_url)

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> str:
        # Assemble the chat messages: optional system prompt, prior turns, then the new user message.
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": text})

        options = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_predict": self.max_tokens,
        }
        if self.seed is not None:
            options["seed"] = self.seed

        # Estimate token usage (prompt plus worst-case completion) for the TPM limiter.
        prompt_tokens = sum(len(self.tokenizer.encode(m["content"])) for m in messages)
        est = prompt_tokens + self.max_tokens

        if self.request_limit:
            await self.rpm.wait(silent=True)
            await self.tpm.wait(est, silent=True)

        response = await self.client.chat(
            model=self.model_name,
            messages=messages,
            format="json" if self.json_mode else "",
            options=options,
            stream=False,
        )

        # Ollama reports prompt and completion token counts separately.
        usage = (response.get("prompt_eval_count", 0), response.get("eval_count", 0))
        self.token_usage.append(
            {
                "prompt_tokens": usage[0],
                "completion_tokens": usage[1],
                "total_tokens": sum(usage),
            }
        )

        content = response["message"]["content"]
        return self.filter_think_tags(content)

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token top-k yet.")

    async def generate_inputs_prob(
        self,
        text: str,
        history: Optional[List[Dict[str, str]]] = None,
        **extra: Any,
    ) -> List[Token]:
        raise NotImplementedError("Ollama API does not support per-token logprobs yet.")
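

# A minimal usage sketch, not part of the client itself. It assumes an Ollama server
# is reachable at the default port with the "gemma3" model pulled, and that
# BaseLLMWrapper supplies defaults for system_prompt, temperature, top_p, max_tokens,
# and tokenizer used by generate_answer above.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = OllamaClient(
            model="gemma3",
            base_url="http://localhost:11434",
            request_limit=True,  # opt in to RPM/TPM throttling before each request
        )
        answer = await client.generate_answer(
            "Explain what a knowledge graph is in one sentence."
        )
        print(answer)
        # Token accounting recorded for the last request.
        print(client.token_usage[-1])

    asyncio.run(_demo())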