File size: 2,501 Bytes
cd10708 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch
def predict_sentiment():
    """Build a binary sentiment classifier.

    Loads the fine-tuned checkpoint from ``./binary_model/checkpoint-400``
    together with the ``cointegrated/rubert-tiny`` tokenizer and wraps both in
    a Hugging Face ``text-classification`` pipeline.

    Returns:
        A callable ``_inner(text)`` that returns a dict with:

        * ``"labels"`` -- the name of the top-scoring label (a single string);
        * ``"probs"`` -- that label's score (a single float).
    """
    model_path = "./binary_model/checkpoint-400"
    model_name = "cointegrated/rubert-tiny"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        # Only the top-scoring label is returned for each input.
        return_all_scores=False,
    )

    def _inner(text: str):
        """Classify *text* and return its top label and that label's score."""
        # Without truncation, inputs longer than the model's maximum sequence
        # length are passed through uncut and can make the forward pass raise.
        # The cut-off comes from the tokenizer's ``model_max_length``.
        pred = clf(text, truncation=True)
        top = pred[0]
        return {
            "labels": top["label"],
            "probs": top["score"],
        }

    return _inner
def predict_category():
    """Build a single-label news-category classifier.

    Loads the fine-tuned checkpoint from ``./category_model/checkpoint-400``
    together with the ``cointegrated/rubert-tiny`` tokenizer and wraps both in
    a Hugging Face ``text-classification`` pipeline that reports only the
    top-scoring label.

    Returns:
        A callable ``_inner(text)`` that returns a dict with:

        * ``"labels"`` -- the four category names, in a fixed order;
        * ``"probs"`` -- a list aligned with ``"labels"`` that holds the
          pipeline's score at the predicted category's position and ``0`` at
          every other position.
    """
    model_path = "./category_model/checkpoint-400"
    model_name = "cointegrated/rubert-tiny"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    clf = pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        # Only the top-scoring label is returned for each input.
        return_all_scores=False,
    )
    # Fixed output order of the category slots. It is built once here, not on
    # every call.
    classes = (
        "политика",
        "экономика",
        "спорт",
        "культура",
    )

    def _inner(text: str):
        """Classify *text* and spread the top score over the fixed class slots."""
        # Without truncation, inputs longer than the model's maximum sequence
        # length are passed through uncut and can make the forward pass raise.
        pred = clf(text, truncation=True)
        top = pred[0]
        # NOTE(review): if the checkpoint's label names do not match these
        # strings (e.g. generic "LABEL_0"), every slot stays 0. Confirm the
        # model config's id2label.
        probs = [top["score"] if cl == top["label"] else 0 for cl in classes]
        return {
            # Return a fresh list per call so callers cannot mutate shared state.
            "labels": list(classes),
            "probs": probs,
        }

    return _inner
def predict_categorys():
    """Build a multi-label news-category classifier.

    Loads the fine-tuned checkpoint from ``./multilabel_model/checkpoint-700``
    together with the ``cointegrated/rubert-tiny`` tokenizer. Each logit is
    passed through an independent sigmoid, so the returned probabilities are
    per-class and need not sum to 1.

    Returns:
        A callable ``_inner(text)`` that returns a dict with:

        * ``"labels"`` -- the four category names, in a fixed order;
        * ``"probs"`` -- ``sigmoid(logit_i)`` for each output ``i`` of the
          model, assumed to align with ``"labels"`` (TODO confirm against the
          training label order).
    """
    model_path = "./multilabel_model/checkpoint-700"
    model_name = "cointegrated/rubert-tiny"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    # Evaluation mode (e.g. disables dropout) for deterministic inference.
    model.eval()
    classes = [
        "политика",
        "экономика",
        "спорт",
        "культура",
    ]

    def _inner(text: str):
        """Return per-class sigmoid probabilities for *text*."""
        # A single sequence needs no padding. The old padding="max_length"
        # added up to 256 masked-out positions per call, costing compute
        # without affecting the attended tokens.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=256,
        )
        with torch.no_grad():
            logits = model(**encoded).logits
        # One input text gives a batch dimension of size 1. Index it instead
        # of squeeze() so a single-logit head still yields a list, not a float.
        probs = torch.sigmoid(logits)[0].tolist()
        return {
            # Fresh list per call so callers cannot mutate the shared names.
            "labels": list(classes),
            "probs": probs,
        }

    return _inner