from __future__ import annotations

import abc
import re
from typing import Any, List, Optional

from graphgen.bases.base_tokenizer import BaseTokenizer
from graphgen.bases.datatypes import Token


class BaseLLMWrapper(abc.ABC):
    """
    LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
    """

    def __init__(
        self,
        *,
        system_prompt: str = "",
        temperature: float = 0.0,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,
        top_p: float = 0.95,
        top_k: int = 50,
        tokenizer: Optional[BaseTokenizer] = None,
        **kwargs: Any,
    ):
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        self.top_p = top_p
        self.top_k = top_k
        self.tokenizer = tokenizer
        # Store any remaining backend-specific options as attributes.
        for k, v in kwargs.items():
            setattr(self, k, v)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        """Generate an answer from the model."""
        raise NotImplementedError

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate the top-k candidate tokens for the next-token prediction."""
        raise NotImplementedError

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate probabilities for each token in the input."""
        raise NotImplementedError

    @staticmethod
    def filter_think_tags(text: str, think_tag: str = "think") -> str:
        """
        Remove <think> tags from the text.
        - If the text contains both <think> and </think>, remove everything
          between them, including the tags themselves.
        - If the text contains only a closing </think>, remove everything
          before and including the tag.
        If filtering leaves nothing behind, fall back to the original text,
        stripped.
        """
        paired_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
        filtered = paired_pattern.sub("", text)
        orphan_pattern = re.compile(rf"^.*?</{think_tag}>", re.DOTALL)
        filtered = orphan_pattern.sub("", filtered)
        filtered = filtered.strip()
        return filtered if filtered else text.strip()
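
    # Illustration of filter_think_tags (the behavior follows directly from
    # the regexes above, assuming the default think_tag="think"):
    #   filter_think_tags("<think>plan</think>answer")  -> "answer"
    #   filter_think_tags("plan</think>answer")         -> "answer"
    #   filter_think_tags("<think>plan</think>")        -> "<think>plan</think>"
    # The last case hits the fallback: filtering removed everything, so the
    # original (stripped) text is returned instead.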

    def shutdown(self) -> None:
        """Shut down the LLM engine, if applicable."""

    def restart(self) -> None:
        """Reinitialize the LLM engine, if applicable."""