import math
from typing import Any, Dict, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class SGLangWrapper(BaseLLMWrapper):
    """
    Async inference backend based on SGLang offline engine.
    """

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_p: float = 1.0,
        topk: int = 5,
        tp_size: int = 1,
        **kwargs: Any,
    ):
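        """
        Args:
            model: Path or name of the model loaded into the offline engine.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling cutoff.
            topk: Number of top candidates to keep per generated token.
            tp_size: Tensor-parallel size passed to the SGLang engine.
        """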
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
        try:
            import sglang as sgl
            from sglang.utils import async_stream_and_merge, stream_and_merge
        except ImportError as exc:
            raise ImportError(
                "SGLangWrapper requires sglang. Install it with: "
                "uv pip install sglang --prerelease=allow"
            ) from exc

        self.model_path: str = model
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk
        self.tp_size = int(tp_size)

        # Initialise the offline engine
        self.engine = sgl.Engine(model_path=self.model_path, tp_size=self.tp_size)

        # Keep helpers for streaming
        self.async_stream_and_merge = async_stream_and_merge
        self.stream_and_merge = stream_and_merge

    @staticmethod
    def _build_sampling_params(
        temperature: float,
        top_p: float,
        max_tokens: int,
        topk: int,
        logprobs: bool = False,
    ) -> Dict[str, Any]:
        """Build SGLang-compatible sampling-params dict."""
        params = {
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
        }
        if logprobs and topk > 0:
            params["logprobs"] = topk
        return params

    def _prep_prompt(self, text: str, history: Optional[List[dict]] = None) -> str:
        """Convert raw text (+ optional history) into a single prompt string."""
        parts = []
        if self.system_prompt:
            parts.append(self.system_prompt)
        if history:
            assert len(history) % 2 == 0, "History must have even length (u/a turns)."
            parts.extend([item["content"] for item in history])
        parts.append(text)
        return "\n".join(parts)

    def _tokens_from_output(self, output: Dict[str, Any]) -> List[Token]:
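        """
        Convert a single SGLang output dict into a list of ``Token`` objects.

        Token ids from ``meta_info["output_token_logprobs"]`` are decoded back
        to text, and top-k alternatives from ``meta_info["output_top_logprobs"]``
        are attached when available.
        """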
        tokens: List[Token] = []

        meta = output.get("meta_info", {})
        logprobs = meta.get("output_token_logprobs", [])
        topks = meta.get("output_top_logprobs", [])

        tokenizer = self.engine.tokenizer_manager.tokenizer

        for idx, (lp, tid, _) in enumerate(logprobs):
            prob = math.exp(lp)
            tok_str = tokenizer.decode([tid])

            top_candidates = []
            if self.topk > 0 and idx < len(topks):
                for t_lp, t_tid, _ in topks[idx][: self.topk]:
                    top_candidates.append(
                        Token(text=tokenizer.decode([t_tid]), prob=math.exp(t_lp))
                    )

            tokens.append(Token(text=tok_str, prob=prob, top_candidates=top_candidates))

        return tokens

    async def generate_answer(
        self,
        text: str,
        history: Optional[List[dict]] = None,
        **extra: Any,
    ) -> str:
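        """Generate a completion for ``text`` (plus optional history) and
        return it with any think tags stripped."""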
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=self.max_tokens,
            topk=0,  # no logprobs needed for simple generation
        )

        outputs = await self.engine.async_generate([prompt], sampling_params)
        return self.filter_think_tags(outputs[0]["text"])

    async def generate_topk_per_token(
        self,
        text: str,
        history: Optional[List[dict]] = None,
        **extra: Any,
    ) -> List[Token]:
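        """Generate the next token for ``text`` and return it as a ``Token``
        list, each entry carrying its probability and top-k candidates."""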
        prompt = self._prep_prompt(text, history)
        sampling_params = self._build_sampling_params(
            temperature=self.temperature,
            top_p=self.top_p,
            max_tokens=1,  # only the next token is needed for top-k analysis
            topk=self.topk,
        )

        outputs = await self.engine.async_generate(
            [prompt], sampling_params, return_logprob=True, top_logprobs_num=self.topk
        )
        return self._tokens_from_output(outputs[0])

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[dict]] = None, **extra: Any
    ) -> List[Token]:
        raise NotImplementedError(
            "SGLangWrapper does not support per-token logprobs yet."
        )

    def shutdown(self) -> None:
        """Gracefully shutdown the SGLang engine."""
        if hasattr(self, "engine"):
            self.engine.shutdown()

    def restart(self) -> None:
        """Restart the SGLang engine."""
        self.shutdown()
        self.engine = self.engine.__class__(
            model_path=self.model_path, tp_size=self.tp_size
        )
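

# Minimal usage sketch (illustrative only): the model path below is a
# placeholder, and it assumes BaseLLMWrapper supplies defaults such as
# max_tokens and system_prompt. Adjust the path and tp_size for your setup.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = SGLangWrapper(model="/path/to/model", tp_size=1)
        try:
            print(await llm.generate_answer("What is a knowledge graph?"))
            tokens = await llm.generate_topk_per_token("The capital of France is")
            print([(t.text, round(t.prob, 4)) for t in tokens])
        finally:
            llm.shutdown()

    asyncio.run(_demo())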