from typing import Any, Dict, List, Optional, Union

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class HuggingFaceWrapper(BaseLLMWrapper):
    """
    Async inference backend based on HuggingFace Transformers.
    """

    def __init__(
        self,
        model: str,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        temperature=0.0,
        top_p=1.0,
        topk=5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)
        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                GenerationConfig,
            )
        except ImportError as exc:
            raise ImportError(
                "HuggingFaceWrapper requires torch, transformers and accelerate. "
                "Install them with: pip install torch transformers accelerate"
            ) from exc
        # Keep module/class handles on the instance so the heavy imports stay
        # local to this optional backend.
        self.torch = torch
        self.AutoTokenizer = AutoTokenizer
        self.AutoModelForCausalLM = AutoModelForCausalLM
        self.GenerationConfig = GenerationConfig

        self.tokenizer = AutoTokenizer.from_pretrained(
            model, trust_remote_code=trust_remote_code
        )
        # Some models (e.g. GPT-2) ship without a pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
        self.model.eval()

        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(
        prompt: str, history: Optional[List[Union[str, Dict[str, str]]]] = None
    ) -> str:
        """Flatten an optional chat history plus the final prompt into text.

        Messages may be plain strings or {"role": ..., "content": ...} dicts;
        this is a simple "role: content" join, not the model's chat template.
        """
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": extra.get("max_new_tokens", 512),
            "do_sample": self.temperature > 0,
            "temperature": self.temperature if self.temperature > 0 else 1.0,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # top_p and top_k only apply when sampling is enabled.
        if self.temperature > 0:
            gen_kwargs.update(top_p=self.top_p, top_k=self.topk)
        gen_config = self.GenerationConfig(**gen_kwargs)

        with self.torch.no_grad():
            out = self.model.generate(**inputs, generation_config=gen_config)

        # Slice off the prompt so only the newly generated tokens are decoded.
        gen = out[0, inputs.input_ids.shape[-1] :]
        return self.tokenizer.decode(gen, skip_special_tokens=True)
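
    # Illustrative call (prompt and token budget are made up):
    #   answer = await wrapper.generate_answer("Summarize this text", max_new_tokens=64)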

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Return the top-k candidates for the next token after the prompt."""
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        with self.torch.no_grad():
            out = self.model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        scores = out.scores[0][0]  # logits of the single generated step, shape (vocab,)
        probs = self.torch.softmax(scores, dim=-1)
        top_probs, top_idx = self.torch.topk(probs, k=self.topk)

        tokens = []
        for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
            tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
        return tokens
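
    # Illustrative result for topk=3 (token strings and probabilities are made up):
    #   [Token(" Paris", 0.91), Token(" Lyon", 0.04), Token(" the", 0.02)]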

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Score each input token with the probability the model assigns to it
        given all preceding tokens (one forward pass, teacher forcing)."""
        full = self._build_inputs(text, history)
        ids = self.tokenizer.encode(full)
        inputs = self.torch.tensor([ids]).to(self.model.device)

        with self.torch.no_grad():
            logits = self.model(inputs).logits[0]  # (seq_len, vocab)
        probs = self.torch.softmax(logits, dim=-1)

        # Logits at position i predict token i + 1, so the first token has no
        # conditional probability and is skipped.
        result = []
        for i in range(1, len(ids)):
            true_id = ids[i]
            result.append(
                Token(
                    self.tokenizer.decode([true_id]),
                    float(probs[i - 1, true_id].cpu()),
                )
            )
        return result
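

if __name__ == "__main__":
    # Minimal usage sketch, assuming a small causal LM is available; the
    # model name "gpt2" and the prompt are illustrative, not prescribed.
    import asyncio

    async def _demo() -> None:
        llm = HuggingFaceWrapper("gpt2", temperature=0.0)
        print(await llm.generate_answer("The capital of France is", max_new_tokens=8))
        print(await llm.generate_topk_per_token("The capital of France is"))
        token_probs = await llm.generate_inputs_prob("The capital of France is")
        print(token_probs[:3])

    asyncio.run(_demo())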