from __future__ import annotations

import abc
import re
from typing import Any, List, Optional

from graphgen.bases.base_tokenizer import BaseTokenizer
from graphgen.bases.datatypes import Token


class BaseLLMWrapper(abc.ABC):
"""
LLM client base class, agnostic to specific backends (OpenAI / Ollama / ...).
"""
def __init__(
self,
*,
system_prompt: str = "",
temperature: float = 0.0,
max_tokens: int = 4096,
repetition_penalty: float = 1.05,
top_p: float = 0.95,
top_k: int = 50,
tokenizer: Optional[BaseTokenizer] = None,
**kwargs: Any,
):
self.system_prompt = system_prompt
self.temperature = temperature
self.max_tokens = max_tokens
self.repetition_penalty = repetition_penalty
self.top_p = top_p
self.top_k = top_k
self.tokenizer = tokenizer
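        # Any remaining keyword arguments are exposed as attributes on the
        # wrapper (e.g. an API key or base URL for a concrete backend).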
for k, v in kwargs.items():
            setattr(self, k, v)

    @abc.abstractmethod
async def generate_answer(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> str:
"""Generate answer from the model."""
raise NotImplementedError
@abc.abstractmethod
async def generate_topk_per_token(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
"""Generate top-k tokens for the next token prediction."""
raise NotImplementedError
@abc.abstractmethod
async def generate_inputs_prob(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
"""Generate probabilities for each token in the input."""
raise NotImplementedError
@staticmethod
    def filter_think_tags(text: str, think_tag: str = "think") -> str:
        """
        Remove <think> reasoning blocks from the text.

        - If the text contains paired <think> and </think> tags, the tags and
          everything between them are removed.
        - If the text contains only a closing </think>, everything up to and
          including the tag is removed.
        - If filtering would leave nothing, the original text (stripped) is
          returned unchanged.
        """
        # Drop paired <think> ... </think> blocks, including the tags.
        paired_pattern = re.compile(rf"<{think_tag}>.*?</{think_tag}>", re.DOTALL)
        filtered = paired_pattern.sub("", text)
        # Drop everything before (and including) an orphan closing tag.
        orphan_pattern = re.compile(rf"^.*?</{think_tag}>", re.DOTALL)
        filtered = orphan_pattern.sub("", filtered)
        filtered = filtered.strip()
        # Fall back to the original text if filtering removed everything.
        return filtered if filtered else text.strip()
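
    # filter_think_tags examples (behaviour follows from the patterns above):
    #   "<think>plan</think>Hello"     -> "Hello"
    #   "stray reasoning</think>Hi"    -> "Hi"
    #   "<think>only thoughts</think>" -> returned unchanged (stripped), since
    #   filtering would otherwise leave an empty string.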

    def shutdown(self) -> None:
        """Shut down the LLM engine if applicable. No-op by default."""

    def restart(self) -> None:
        """Reinitialize the LLM engine if applicable. No-op by default."""