File size: 1,555 Bytes
8c66169
3a3b216
acd7cf4
3a3b216
acd7cf4
 
8c66169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acd7cf4
8c66169
acd7cf4
 
 
 
 
 
 
8c66169
3a3b216
 
 
acd7cf4
 
 
 
 
 
8c66169
3a3b216
 
 
acd7cf4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from abc import ABC, abstractmethod
from typing import List, Optional

from graphgen.bases import Token


class TopkTokenModel(ABC):
    """Abstract base class for models exposing per-token top-k information.

    Concrete subclasses wrap a specific backend and implement three async
    entry points: full-answer generation, per-output-token probability /
    candidate extraction, and per-input-token probability extraction.
    The constructor only records shared decoding configuration; it performs
    no I/O and loads no model.
    """

    def __init__(
        self,
        do_sample: bool = False,
        temperature: float = 0,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,
        num_beams: int = 1,
        topk: int = 50,
        topp: float = 0.95,
        topk_per_token: int = 5,
    ):
        # Sampling switches and temperature (temperature 0 pairs with
        # greedy decoding when do_sample is False).
        self.do_sample = do_sample
        self.temperature = temperature
        # Nucleus / top-k sampling cutoffs.
        self.topk = topk
        self.topp = topp
        # Search and length controls.
        self.num_beams = num_beams
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        # How many candidate alternatives to record per generated token
        # (used by the visualization paths below).
        self.topk_per_token = topk_per_token

    @abstractmethod
    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """Return one Token per generated position, each carrying its
        probability, text, and top candidate alternatives.

        Used to visualize the inference process.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        """Return one Token per *input* position with its probability and
        text, given optional conversation history.

        Used to visualize perplexity (ppl).
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None
    ) -> str:
        """Return the model's full text answer for *text*, given optional
        conversation history."""
        raise NotImplementedError