from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class HuggingFaceWrapper(BaseLLMWrapper):
    """
    Async inference backend based on HuggingFace Transformers
    """

    def __init__(
        self,
        model: str,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        temperature=0.0,
        top_p=1.0,
        topk=5,
        **kwargs: Any,
    ):
        super().__init__(temperature=temperature, top_p=top_p, **kwargs)

        try:
            import torch
            from transformers import (
                AutoModelForCausalLM,
                AutoTokenizer,
                GenerationConfig,
            )
        except ImportError as exc:
            raise ImportError(
                "HuggingFaceWrapper requires torch, transformers and accelerate. "
                "Install them with:  pip install torch transformers accelerate"
            ) from exc

        self.torch = torch
        self.AutoTokenizer = AutoTokenizer
        self.AutoModelForCausalLM = AutoModelForCausalLM
        self.GenerationConfig = GenerationConfig

        self.tokenizer = AutoTokenizer.from_pretrained(
            model, trust_remote_code=trust_remote_code
        )
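        # Many causal-LM tokenizers ship without a pad token; fall back to the
        # EOS token so padded generation does not fail.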
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )
        self.model.eval()
        self.temperature = temperature
        self.top_p = top_p
        self.topk = topk

    @staticmethod
    def _build_inputs(prompt: str, history: Optional[List[str]] = None) -> str:
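        # Flatten the chat history and the current prompt into one plain-text
        # string (no chat template is applied). History entries may be
        # chat-style dicts with "role"/"content" keys or plain strings.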
        msgs = history or []
        lines = []
        for m in msgs:
            if isinstance(m, dict):
                role = m.get("role", "")
                content = m.get("content", "")
                lines.append(f"{role}: {content}")
            else:
                lines.append(str(m))
        lines.append(prompt)
        return "\n".join(lines)

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        # temperature == 0 selects greedy decoding; any positive temperature
        # enables sampling with the configured top-p / top-k.
        gen_kwargs = {
            "max_new_tokens": extra.get("max_new_tokens", 512),
            "do_sample": self.temperature > 0,
            "temperature": self.temperature if self.temperature > 0 else 1.0,
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        # Add top_p and top_k only if temperature > 0
        if self.temperature > 0:
            gen_kwargs.update(top_p=self.top_p, top_k=self.topk)

        gen_config = self.GenerationConfig(**gen_kwargs)

        with self.torch.no_grad():
            out = self.model.generate(**inputs, generation_config=gen_config)

        # Decode only the newly generated continuation, excluding the prompt tokens.
        gen = out[0, inputs.input_ids.shape[-1] :]
        return self.tokenizer.decode(gen, skip_special_tokens=True)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full = self._build_inputs(text, history)
        inputs = self.tokenizer(full, return_tensors="pt").to(self.model.device)

        with self.torch.no_grad():
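            # Generate exactly one new token and request its raw scores so the
            # top-k candidates for the next position can be ranked.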
            out = self.model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                temperature=1.0,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        scores = out.scores[0][0]  # (vocab,)
        probs = self.torch.softmax(scores, dim=-1)
        top_probs, top_idx = self.torch.topk(probs, k=self.topk)

        tokens = []
        for p, idx in zip(top_probs.cpu().numpy(), top_idx.cpu().numpy()):
            tokens.append(Token(self.tokenizer.decode([idx]), float(p)))
        return tokens

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        full = self._build_inputs(text, history)
        ids = self.tokenizer.encode(full)

        # One forward pass over the whole sequence: the logits at position t
        # give the model's distribution for token t + 1, so the probability of
        # ids[i] is read from position i - 1. The first token has no left
        # context under a causal LM and is skipped.
        input_ids = self.torch.tensor([ids]).to(self.model.device)
        with self.torch.no_grad():
            logits = self.model(input_ids).logits[0]
        probs = self.torch.softmax(logits, dim=-1)

        token_probs = []
        for i in range(1, len(ids)):
            true_id = ids[i]
            token_probs.append(
                Token(
                    self.tokenizer.decode([true_id]),
                    float(probs[i - 1, true_id].cpu()),
                )
            )
        return token_probs
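

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the wrapper). The model
# name below is a placeholder assumption; substitute any causal LM checkpoint
# available locally or on the Hugging Face Hub.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Placeholder checkpoint; loading it requires the weights to be
        # available locally or downloadable from the Hub.
        llm = HuggingFaceWrapper(model="Qwen/Qwen2.5-0.5B-Instruct")

        answer = await llm.generate_answer("What is a knowledge graph?")
        print("answer:", answer)

        # Top-k candidate tokens (with probabilities) for the next position.
        for tok in await llm.generate_topk_per_token("The capital of France is"):
            print(tok)

    asyncio.run(_demo())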