# handler.py — custom Hugging Face Inference Endpoint handler (handler_trial)
from typing import Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class EndpointHandler():
    def __init__(self, path=""):
        """Load GPT-2 (fp16, hidden states enabled) and its tokenizer onto the GPU.

        Args:
            path: checkpoint path supplied by the endpoint runtime; unused here
                because the model id is hard-coded to "gpt2".
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            "gpt2", torch_dtype=torch.float16, output_hidden_states=True
        )
        # NOTE: hard requirement on CUDA — fp16 inference on CPU is unsupported.
        self.model = self.model.cuda()
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
        """Compute a mask-weighted mean logit vector for each input document.

        Args:
            data: request payload; "inputs" is expected to hold an iterable of
                document strings. If the key is absent, `data` itself is
                iterated (presumably already a list — TODO confirm callers).

        Return:
            {"logits": [...]} — one nested list (the sequence-averaged logits)
            per input document.
        """
        inputs = data.pop("inputs", data)
        all_logits = []
        for doc in inputs:
            # BUG FIX: the original tokenized `inputs` (the whole list) on
            # every iteration instead of the current document `doc`, so each
            # loop pass produced identical batch-sized output.
            tokenized = self.tokenizer(
                doc,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            token_ids = tokenized.input_ids.cuda()
            token_mask = tokenized.attention_mask.cuda()
            with torch.no_grad():
                out = self.model(token_ids, attention_mask=token_mask)
            # Zero out padding positions, then average logits over the
            # sequence dimension (dim 1).
            meaned_logits = (out.logits * token_mask.unsqueeze(-1)).sum(1) / token_mask.sum(
                1
            ).unsqueeze(-1)
            # Removed dead code: the original also computed a sorted-logits
            # mean (`mean_sorted_logits`) that was never used or returned.
            all_logits.append(meaned_logits.cpu().numpy().tolist())
        # postprocess the prediction
        return {"logits": all_logits}