import math
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class SGLangWrapper(BaseLLMWrapper):
    """
    Async inference backend based on the SGLang offline engine.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        tp_size: int = 1,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            import sglang as sgl
            from sglang.utils import async_stream_and_merge, stream_and_merge
        except ImportError as exc:
            raise ImportError(
                "SGLangWrapper requires sglang. Install it with: "
                "uv pip install sglang --prerelease=allow"
            ) from exc

        self.model_path: str = model
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk
        self.tp_size = int(tp_size)

        # Initialise the offline engine
        self.engine = sgl.Engine(model_path=self.model_path, tp_size=self.tp_size)

        # Keep helpers for streaming
        self.async_stream_and_merge = async_stream_and_merge
        self.stream_and_merge = stream_and_merge

    @staticmethod
    def _build_sampling_params(
        temperature: float,
        top_p: float,
        max_tokens: int,
        topk: int,
        logprobs: bool = False,
    ) -> Dict[str, Any]:
        """Build an SGLang-compatible sampling-params dict."""
        params = {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
        }
        if logprobs and topk > 0:
            params["logprobs"] = topk
        return params

    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
        """Convert raw text (+ optional history) into a single prompt string."""
        parts = []
        if self.system_prompt:
            parts.append(self.system_prompt)
        if history:
            assert len(history) % 2 == 0, "History must have an even number of user/assistant turns."
            parts.extend([item["content"] for item in history])
        parts.append(text)
        return "\n".join(parts)

    def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
        """Convert one SGLang output dict into a list of Token objects."""
        tokens: List[Token] = []
        meta = output.get("meta_info", {})
        logprobs = meta.get("output_token_logprobs", [])
        topks = meta.get("output_top_logprobs", [])
        tokenizer = self.engine.tokenizer_manager.tokenizer
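        # Each logprob entry is a (logprob, token_id, token_text) tuple in
        # SGLang's meta_info format; token_text may be None, so decode the id.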
        for idx, (lp, tid, _) in enumerate(logprobs):
            prob = math.exp(lp)
            tok_str = tokenizer.decode([tid])
            top_candidates = []
            if self.topk > 0 and idx < len(topks):
                for t_lp, t_tid, _ in topks[idx][: self.topk]:
                    top_candidates.append(
                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
                    )
            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))
        return tokens

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[dict]] = None,
        **extra: Any,
    ) -> str:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=self.max_tokens,
            topk=0,  # no logprobs needed for simple generation
        )
        outputs = await self.engine.async_generate([prompt], sampling_params)
        return self.filter_think_tags(outputs[0]["text"])

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[dict]] = None,
        **extra: Any,
    ) -> List[Token]:
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=1,  # generate a single token for token-level analysis
            topk=self.topk,
        )
        outputs = await self.engine.async_generate(
            [prompt], sampling_params, return_logprob=True, top_logprobs_num=self.topk
        )
        return self._tokens_from_output(outputs[0])

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[dict]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "SGLangWrapper does not support per-token probabilities over the input text yet."
        )

    def shutdown(self) -> None:
        """Gracefully shut down the SGLang engine."""
        if hasattr(self, "engine"):
            self.engine.shutdown()

    def restart(self) -> None:
        """Restart the SGLang engine."""
        self.shutdown()
        self.engine = self.engine.__class__(
            model_path=self.model_path, tp_size=self.tp_size
        )
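

if __name__ == "__main__":
    # Minimal usage sketch. The model path below is an illustrative
    # assumption; substitute any SGLang-supported checkpoint.
    import asyncio

    async def _demo() -> None:
        llm = SGLangWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")
        try:
            answer = await llm.generate_answer("What is a knowledge graph?")
            print(answer)
        finally:
            llm.shutdown()

    asyncio.run(_demo())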