Spaces:
Running
Running
File size: 1,555 Bytes
8c66169 3a3b216 acd7cf4 3a3b216 acd7cf4 8c66169 acd7cf4 8c66169 acd7cf4 8c66169 3a3b216 acd7cf4 8c66169 3a3b216 acd7cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
from abc import ABC, abstractmethod
from typing import List, Optional
from graphgen.bases import Token
class TopkTokenModel(ABC):
    """Base interface for language-model backends that expose per-token data.

    Holds the decoding hyper-parameters shared by all concrete backends and
    declares the async generation hooks each backend must implement.
    """

    def __init__(
        self,
        do_sample: bool = False,
        temperature: float = 0,
        max_tokens: int = 4096,
        repetition_penalty: float = 1.05,
        num_beams: int = 1,
        topk: int = 50,
        topp: float = 0.95,
        topk_per_token: int = 5,
    ):
        # Decoding-strategy switches.
        self.do_sample = do_sample
        self.num_beams = num_beams
        # Sampling-distribution controls.
        self.temperature = temperature
        self.topk = topk
        self.topp = topp
        # Output-length / repetition constraints.
        self.max_tokens = max_tokens
        self.repetition_penalty = repetition_penalty
        # Number of candidate tokens recorded per generated position.
        self.topk_per_token = topk_per_token

    @abstractmethod
    async def generate_topk_per_token(self, text: str) -> List[Token]:
        """Return prob, text and top candidates for each output token.

        Used to visualize the inference process.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        """Return prob and text for each token of the input text.

        Used to visualize the ppl.
        """
        raise NotImplementedError

    @abstractmethod
    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None
    ) -> str:
        """Return the model's answer for ``text``."""
        raise NotImplementedError
|