from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
def __init__(self, path=""): |
|
|
|
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
|
"gpt2", torch_dtype=torch.float16, output_hidden_states=True |
|
|
) |
|
|
self.model = self.model.cuda() |
|
|
self.tokenizer = AutoTokenizer.from_pretrained("gpt2") |
|
|
|
|
|
|
|
|
    def __call__(self, data: Dict[str, Any]) -> Dict[str, List]:
        """
        Args:
            data (:obj:`dict`):
                includes the input documents as a list of strings under the
                "inputs" key
        Return:
            A :obj:`dict` mapping "logits" to one mask-averaged logits vector
            per input document
        """
        inputs = data.pop("inputs", data)
        all_logits = []

        for doc in inputs:
            tokenized = self.tokenizer(
                doc,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            token_ids = tokenized.input_ids.cuda()
            token_mask = tokenized.attention_mask.cuda()
            with torch.no_grad():
                out = self.model(token_ids, attention_mask=token_mask)
            # Mean-pool the per-token logits over the sequence, masking out
            # padded positions.
            meaned_logits = (out.logits * token_mask.unsqueeze(-1)).sum(
                1
            ) / token_mask.sum(1).unsqueeze(-1)
            # Per-position sorted logits, pooled the same way; computed but not
            # currently included in the response.
            sorted_logits = torch.sort(out.logits).values
            mean_sorted_logits = (sorted_logits * token_mask.unsqueeze(-1)).sum(
                1
            ) / token_mask.sum(1).unsqueeze(-1)
            all_logits.append(meaned_logits.cpu().numpy().tolist())

        return {"logits": all_logits}