Upload 2 files
- ebert_model.py +40 -0
- proverka.py +66 -0
ebert_model.py
ADDED
@@ -0,0 +1,40 @@
+from transformers import BertConfig, BertModel
+import torch.nn as nn
+
+class EBertConfig(BertConfig):
+    model_type = "ebert"
+
+    def __init__(self, **kwargs):
+        # Pop the custom field before handing the remaining kwargs to BertConfig.
+        self.adapter_size = kwargs.pop('adapter_size', None)
+        super().__init__(**kwargs)
+
+class EBertModel(BertModel):
+    config_class = EBertConfig
+
+    def __init__(self, config: EBertConfig):
+        super().__init__(config)
+        if config.adapter_size:
+            # One bottleneck adapter (hidden -> adapter_size -> hidden) per encoder layer.
+            self.adapters = nn.ModuleList([
+                nn.Sequential(
+                    nn.Linear(config.hidden_size, config.adapter_size),
+                    nn.ReLU(),
+                    nn.Linear(config.adapter_size, config.hidden_size),
+                )
+                for _ in range(config.num_hidden_layers)
+            ])
+        else:
+            self.adapters = None
+
+    def forward(self, *args, **kwargs):
+        outputs = super().forward(*args, **kwargs)
+        sequence_output = outputs.last_hidden_state
+
+        if self.adapters is not None:
+            # Apply each adapter as a residual update on the final hidden states.
+            for adapter in self.adapters:
+                sequence_output = sequence_output + adapter(sequence_output)
+
+        return outputs.__class__(
+            last_hidden_state=sequence_output,
+            pooler_output=outputs.pooler_output,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
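As a point of reference, here is one plausible way to produce the ./ebert directory that the companion script below expects. The base checkpoint name and adapter width are illustrative assumptions, not something this commit specifies:

    from transformers import BertTokenizerFast
    from ebert_model import EBertConfig, EBertModel

    # Hypothetical base checkpoint and adapter bottleneck width.
    base = "bert-base-multilingual-cased"
    config = EBertConfig.from_pretrained(base, adapter_size=64)
    model = EBertModel.from_pretrained(base, config=config)  # adapters start randomly initialized

    model.save_pretrained("./ebert")  # writes config.json and model.safetensors
    BertTokenizerFast.from_pretrained(base).save_pretrained("./ebert")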
proverka.py
ADDED
@@ -0,0 +1,66 @@
+from transformers import BertTokenizerFast, BertForMaskedLM
+import torch
+from datasets import load_dataset
+from ebert_model import EBertConfig, EBertModel
+from safetensors.torch import load_file
+import os
+
+model_path = "./ebert"
+tokenizer = BertTokenizerFast.from_pretrained(model_path)
+
+config = EBertConfig.from_pretrained(model_path)
+model = BertForMaskedLM(config)
+# Swap the stock encoder for the adapter-augmented variant.
+model.bert = EBertModel(config)
+
+weights_path = f"{model_path}/model.safetensors"
+if os.path.exists(weights_path):
+    state_dict = load_file(weights_path)
+    model.load_state_dict(state_dict, strict=False)
+else:
+    raise FileNotFoundError(f"File {weights_path} not found.")
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+model.eval()
+dataset = load_dataset("Expotion/russian-facts-qa", split="train")
+
+def predict_masked_text(example):
+    text = f"{example['q'].strip()} {example['a'].strip()}"
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    input_ids = inputs["input_ids"].clone()
+    labels = input_ids.clone()
+    # Mask a single token at a fixed position (the third token, index 2).
+    mask_token_index = 2
+    if input_ids.size(1) <= mask_token_index:
+        return {
+            "original_text": text,
+            "masked_text": "Text is too short",
+            "predicted_tokens": [],
+            "true_token": ""
+        }
+    input_ids[0, mask_token_index] = tokenizer.mask_token_id
+    # Write the masked ids back so the model actually sees the [MASK] token.
+    inputs["input_ids"] = input_ids
+    with torch.no_grad():
+        outputs = model(**inputs)
+        predictions = outputs.logits
+    predicted_token_ids = torch.topk(predictions[0, mask_token_index], 5).indices.tolist()
+    predicted_tokens = [tokenizer.decode([token_id]).strip() for token_id in predicted_token_ids]
+
+    # Trim padding but keep [MASK] visible in the printed text.
+    real_tokens = input_ids[0][inputs["attention_mask"][0].bool()]
+    return {
+        "original_text": text,
+        "masked_text": tokenizer.decode(real_tokens),
+        "predicted_tokens": predicted_tokens,
+        "true_token": tokenizer.decode([labels[0, mask_token_index]], skip_special_tokens=True)
+    }
+
+total_params = sum(p.numel() for p in model.parameters())
+print(f"Total number of parameters: {total_params}")
+
+num_examples = 1
+for i in range(num_examples):
+    result = predict_masked_text(dataset[i])
+    print(f"Original: {result['original_text']}")
+    print(f"Masked: {result['masked_text']}")
+    print(f"Prediction: {result['predicted_tokens']}")
+    print(f"Ground truth: {result['true_token']}")