File size: 504 Bytes
63d4ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class BaseModel:
    """Interface for model backends used for (batch or streaming) inference.

    Subclasses are expected to override :meth:`infer` and
    :meth:`stream_infer`.  The base implementations raise
    :class:`NotImplementedError` so that a missing override fails loudly at
    the call site instead of silently returning ``None``.
    """

    def infer(self,
              prompts,
              top_p=0.95,
              temperature=0.3,
              repetition_penalty=1.2):
        """Run inference on a batch of prompts.

        Takes a list of prompts and returns the output hidden states.

        Args:
            prompts: List of prompts to process in a single call.
            top_p: Sampling parameter, conventionally the nucleus (top-p)
                probability mass; its exact meaning is defined by the
                subclass.
            temperature: Sampling temperature; its exact meaning is defined
                by the subclass.
            repetition_penalty: Presumably a penalty on repeated tokens; its
                exact meaning is defined by the subclass.

        Returns:
            The output hidden states for ``prompts``, in a format defined by
            the subclass.

        Raises:
            NotImplementedError: Always, in the base class; subclasses must
                override this method.
        """
        # Fail loudly rather than returning None, which would surface later
        # as an unrelated error far from the missing override.
        raise NotImplementedError(
            f"{type(self).__name__} must implement infer()")

    def stream_infer(self,
                     prompt,
                     top_p=0.95,
                     temperature=0.3,
                     repetition_penalty=1.2):
        """Run inference on a single prompt, yielding results incrementally.

        Takes a prompt and returns an iterator of the output hidden states.

        Args:
            prompt: A single prompt.
            top_p: Sampling parameter, conventionally the nucleus (top-p)
                probability mass; its exact meaning is defined by the
                subclass.
            temperature: Sampling temperature; its exact meaning is defined
                by the subclass.
            repetition_penalty: Presumably a penalty on repeated tokens; its
                exact meaning is defined by the subclass.

        Returns:
            An iterator over the output hidden states, in a format defined by
            the subclass.

        Raises:
            NotImplementedError: Always, in the base class; subclasses must
                override this method.
        """
        # This is a plain function (not a generator), so the error is raised
        # immediately on call rather than on first iteration.
        raise NotImplementedError(
            f"{type(self).__name__} must implement stream_infer()")