from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path=""):
        # load the model in half precision; fall back to CPU and full
        # precision when no CUDA device is available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            "gpt2", torch_dtype=dtype, output_hidden_states=True
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only; disable dropout
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")


    def __call__(self, data: Dict[str, Any]) -> Dict[str, List]:
        """
        Args:
            data (:obj:`dict`):
                includes the input text(s) under the ``inputs`` key,
                either a single string or a list of strings
        Return:
            A :obj:`dict` mapping ``"logits"`` to the mean-pooled
            logits of each input text
        """
        # process input; accept either a single string or a list of
        # documents (a bare string would otherwise be iterated per character)
        inputs = data.pop("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        all_logits = []

        for doc in inputs:
            tokenized = self.tokenizer(
                doc,  # encode the current document, not the whole batch
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            token_ids = tokenized.input_ids.to(self.device)
            token_mask = tokenized.attention_mask.to(self.device)
            with torch.no_grad():
                out = self.model(token_ids, attention_mask=token_mask)
            # mean-pool the logits over the sequence dimension, counting
            # only the positions the attention mask marks as real tokens
            meaned_logits = (out.logits * token_mask.unsqueeze(-1)).sum(
                1
            ) / token_mask.sum(1).unsqueeze(-1)
            # the same pooling applied to logits sorted along the vocabulary
            # dimension; computed here but not included in the response
            sorted_logits = torch.sort(out.logits).values
            mean_sorted_logits = (sorted_logits * token_mask.unsqueeze(-1)).sum(
                1
            ) / token_mask.sum(1).unsqueeze(-1)
            all_logits.append(meaned_logits.cpu().numpy().tolist())
        
        # postprocess the prediction
        return {"logits": all_logits}
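

# Sanity-check sketch for the masked mean pooling used above (assumption:
# illustration only, not part of the endpoint contract). With a single
# unpadded document the attention mask is all ones, so the masked mean
# reduces to a plain mean over the sequence dimension.
def _check_masked_mean() -> bool:
    logits = torch.randn(1, 7, 11)  # (batch, seq_len, vocab_size)
    mask = torch.ones(1, 7)  # every position is a real token
    masked_mean = (logits * mask.unsqueeze(-1)).sum(1) / mask.sum(1).unsqueeze(-1)
    return torch.allclose(masked_mean, logits.mean(dim=1))  # expected: True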
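

# Minimal local usage sketch (assumption: this mirrors how Hugging Face
# Inference Endpoints drive a custom handler, instantiating it once and
# calling it with a JSON-decoded payload whose text lives under the
# "inputs" key; the sample documents below are illustrative only).
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": ["Hello world", "A second, longer document."]})
    # one entry per input document; each entry has shape [1, vocab_size]
    print(len(result["logits"]), len(result["logits"][0][0]))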