# graphgen/models/llm/local/sglang_wrapper.py
import math
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class SGLangWrapper(BaseLLMWrapper):
    """
    Async inference backend based on the SGLang offline engine.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        tp_size: int = 1,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            import sglang as sgl
            from sglang.utils import async_stream_and_merge, stream_and_merge
        except ImportError as exc:
            raise ImportError(
                "SGLangWrapper requires sglang. Install it with: "
                "uv pip install sglang --prerelease=allow"
            ) from exc

        self.model_path: str = model
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk
        self.tp_size = int(tp_size)

        # Initialise the offline engine
        self.engine = sgl.Engine(model_path=self.model_path, tp_size=self.tp_size)

        # Keep helpers for streaming
        self.async_stream_and_merge = async_stream_and_merge
        self.stream_and_merge = stream_and_merge

    @staticmethod
    def _build_sampling_params(
        temperature: float,
        top_p: float,
        max_tokens: int,
        topk: int,
        logprobs: bool = False,
    ) -> Dict[str, Any]:
        """Build SGLang-compatible sampling-params dict."""
        params = {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
        }
        if logprobs and topk > 0:
            params["logprobs"] = topk
        return params
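
    # Illustrative shape of the dict this helper returns (values assumed:
    # temperature=0.0, top_p=1.0, max_tokens=256, topk=5, logprobs=True):
    #   {"temperature": 0.0, "top_p": 1.0, "max_new_tokens": 256, "logprobs": 5}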

    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
        """Convert raw text (+ optional history) into a single prompt string."""
        parts = []
        if self.system_prompt:
            parts.append(self.system_prompt)
        if history:
            assert (
                len(history) % 2 == 0
            ), "History must have an even length (alternating user/assistant turns)."
            parts.extend([item["content"] for item in history])
        parts.append(text)
        return "\n".join(parts)

    def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
        """Convert a single SGLang output dict into a list of Token objects."""
        tokens: List[Token] = []
        meta = output.get("meta_info", {})
        logprobs = meta.get("output_token_logprobs", [])
        topks = meta.get("output_top_logprobs", [])
        tokenizer = self.engine.tokenizer_manager.tokenizer

        for idx, (lp, tid, _) in enumerate(logprobs):
            prob = math.exp(lp)
            tok_str = tokenizer.decode([tid])

            top_candidates = []
            if self.topk > 0 and idx < len(topks):
                for t_lp, t_tid, _ in topks[idx][: self.topk]:
                    top_candidates.append(
                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
                    )

            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))
        return tokens
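
    # Expected `meta_info` layout, as consumed above (exact field semantics may
    # vary with the installed sglang version):
    #   "output_token_logprobs": [(logprob, token_id, ...), ...]        # one tuple per generated token
    #   "output_top_logprobs":   [[(logprob, token_id, ...), ...], ...]  # per-token top-k candidates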

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[dict]] = None,
        **extra: Any,
    ) -> str:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=self.max_tokens,
            topk=0,  # no logprobs needed for simple generation
        )
        outputs = await self.engine.async_generate([prompt], sampling_params)
        return self.filter_think_tags(outputs[0]["text"])

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[dict]] = None,
        **extra: Any,
    ) -> List[Token]:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=1,  # keep generation short for token-level analysis
            topk=self.topk,
        )
        outputs = await self.engine.async_generate(
            [prompt], sampling_params, return_logprob=True, top_logprobs_num=self.topk
        )
        return self._tokens_from_output(outputs[0])

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[dict]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "SGLangWrapper does not support computing logprobs for input tokens yet."
        )

    def shutdown(self) -> None:
        """Gracefully shut down the SGLang engine."""
        if hasattr(self, "engine"):
            self.engine.shutdown()

    def restart(self) -> None:
        """Restart the SGLang engine."""
        self.shutdown()
        self.engine = self.engine.__class__(
            model_path=self.model_path, tp_size=self.tp_size
        )
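

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The model path below
# is illustrative, and it assumes BaseLLMWrapper accepts `max_tokens` via
# **kwargs, since generate_answer reads self.max_tokens. Adjust to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = SGLangWrapper(
            model="/path/to/local/model",  # hypothetical local model directory
            temperature=0.0,
            topk=5,
            tp_size=1,
            max_tokens=128,  # assumption: forwarded to BaseLLMWrapper via **kwargs
        )
        try:
            answer = await llm.generate_answer("What is a knowledge graph?")
            print(answer)

            tokens = await llm.generate_topk_per_token("Paris is the capital of")
            for tok in tokens:
                print(tok.text, tok.prob, [c.text for c in tok.top_candidates])
        finally:
            llm.shutdown()

    asyncio.run(_demo())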