import math
import uuid
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class VLLMWrapper(BaseLLMWrapper):
"""
Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
"""
def __init__(
self,
model: str,
tensor_parallel_size: int = 1,
gpu_memory_utilization: float = 0.9,
temperature: float = 0.0,
top_p: float = 1.0,
topk: int = 5,
**kwargs: Any,
):
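        """
        Args:
            model: Hugging Face model name or local path handed to vLLM.
            tensor_parallel_size: number of GPUs to shard the model across.
            gpu_memory_utilization: fraction of GPU memory vLLM may allocate.
            temperature / top_p: default sampling parameters.
            topk: number of candidates returned by generate_topk_per_token.
            **kwargs: forwarded to BaseLLMWrapper (e.g. trust_remote_code).
        """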
super().__init__(temperature=temperature, top_p=top_p, **kwargs)
try:
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
except ImportError as exc:
raise ImportError(
"VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
) from exc
self.SamplingParams = SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
trust_remote_code=kwargs.get("trust_remote_code", True),
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
self.temperature = temperature
self.top_p = top_p
        self.topk = topk

@staticmethod
def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
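        """Flatten an optional chat history plus the current prompt into a single
        plain-text prompt. History items may be plain strings or
        {"role": ..., "content": ...} dicts."""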
msgs = history or []
lines = []
for m in msgs:
if isinstance(m, dict):
role = m.get("role", "")
content = m.get("content", "")
lines.append(f"{role}: {content}")
else:
lines.append(str(m))
lines.append(prompt)
return "\n".join(lines)
async def generate_answer(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> str:
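        """Generate a free-form completion for `text`, optionally conditioned on
        `history`. `max_new_tokens` may be passed via `extra` (default 512)."""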
full_prompt = self._build_inputs(text, history)
        sp = self.SamplingParams(
            # If temperature == 0, fall back to plain sampling (temperature=1.0,
            # top_p=1.0) rather than vLLM's greedy mode.
            temperature=self.temperature if self.temperature > 0 else 1.0,
            top_p=self.top_p if self.temperature > 0 else 1.0,
            max_tokens=extra.get("max_new_tokens", 512),
        )
        results = []
        # request_id must be unique per in-flight request in AsyncLLMEngine,
        # so derive it from a fresh uuid instead of reusing a fixed string.
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id=f"graphgen_req_{uuid.uuid4().hex}"
        ):
            results = req_output.outputs
        return results[-1].text

async def generate_topk_per_token(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
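        """Generate exactly one token and return its top-k candidates as Token
        objects sorted by probability (descending)."""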
full_prompt = self._build_inputs(text, history)
sp = self.SamplingParams(
temperature=0,
max_tokens=1,
logprobs=self.topk,
)
        results = []
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id=f"graphgen_topk_{uuid.uuid4().hex}"
        ):
            results = req_output.outputs
        # logprobs[0] holds the top-k candidate dict for the single generated token.
        top_logprobs = results[-1].logprobs[0]
        tokens = []
        for logprob_obj in top_logprobs.values():
            tok_str = logprob_obj.decoded_token
            # Logprob.logprob is a plain float, so exponentiate with math.exp.
            prob = float(math.exp(logprob_obj.logprob))
            tokens.append(Token(tok_str, prob))
        tokens.sort(key=lambda x: -x.prob)
        return tokens

async def generate_inputs_prob(
self, text: str, history: Optional[List[str]] = None, **extra: Any
) -> List[Token]:
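        """Return the probability of each token of the input prompt itself,
        as scored by the model via prompt_logprobs."""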
full_prompt = self._build_inputs(text, history)
        # vLLM has no ready-made "mask one token and recompute its probability"
        # interface, so we take the most direct route: submit the prompt once with
        # prompt_logprobs enabled, let vLLM return the logprob of every position
        # of the *input*, then read off the probability of each token.
        sp = self.SamplingParams(
            temperature=0,
            max_tokens=1,  # vLLM requires max_tokens >= 1; the generated token is ignored
            prompt_logprobs=1,  # top-1 per prompt position is enough
        )
        final_output = None
        async for req_output in self.engine.generate(
            full_prompt, sp, request_id=f"graphgen_prob_{uuid.uuid4().hex}"
        ):
            final_output = req_output
        # prompt_logprobs lives on the RequestOutput itself (not on the
        # CompletionOutput): a list with one entry per prompt token, each a
        # dict {token_id: Logprob}; the first position is None.
        prompt_logprobs = final_output.prompt_logprobs
        tokens = []
        for logprob_dict in prompt_logprobs:
            if logprob_dict is None:
                continue
            # Each dict holds a single entry because we asked for top-1.
            _, logprob_obj = next(iter(logprob_dict.items()))
            tok_str = logprob_obj.decoded_token
            prob = float(math.exp(logprob_obj.logprob))
            tokens.append(Token(tok_str, prob))
return tokens
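

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the wrapper): drives
# the async methods from a single asyncio event loop. The model name below is
# a placeholder; any locally available vLLM-compatible model works.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = VLLMWrapper(model="Qwen/Qwen2.5-7B-Instruct", tensor_parallel_size=1)

        # Free-form generation
        print(await llm.generate_answer("What is a knowledge graph?"))

        # Top-k candidates for the next token
        for token in await llm.generate_topk_per_token("The capital of France is"):
            print(token)

        # Per-token probabilities of the prompt itself
        for token in await llm.generate_inputs_prob("Paris is the capital of France."):
            print(token)

    asyncio.run(_demo())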