Darkester commited on
Commit
b61b1b5
·
verified ·
1 Parent(s): 87169f1

Upload 2 files

Browse files
Files changed (2) hide show
  1. ebert_model.py +40 -0
  2. proverka.py +66 -0
ebert_model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig, BertModel
2
+ import torch.nn as nn
3
+
4
+ class EBertConfig(BertConfig):
5
+ model_type = "ebert"
6
+ def __init__(self, **kwargs):
7
+ super().__init__(**kwargs)
8
+ self.adapter_size = kwargs.pop('adapter_size', None)
9
+
10
+ class EBertModel(BertModel):
11
+ config_class = EBertConfig
12
+
13
+ def __init__(self, config: EBertConfig):
14
+ super().__init__(config)
15
+ if config.adapter_size:
16
+ self.adapters = nn.ModuleList([
17
+ nn.Sequential(
18
+ nn.Linear(config.hidden_size, config.adapter_size),
19
+ nn.ReLU(),
20
+ nn.Linear(config.adapter_size, config.hidden_size),
21
+ )
22
+ for _ in range(config.num_hidden_layers)
23
+ ])
24
+ else:
25
+ self.adapters = None
26
+
27
+ def forward(self, *args, **kwargs):
28
+ outputs = super().forward(*args, **kwargs)
29
+ sequence_output = outputs.last_hidden_state
30
+
31
+ if self.adapters is not None:
32
+ for adapter in self.adapters:
33
+ sequence_output = sequence_output + adapter(sequence_output)
34
+
35
+ return outputs.__class__(
36
+ last_hidden_state=sequence_output,
37
+ pooler_output=outputs.pooler_output,
38
+ hidden_states=outputs.hidden_states,
39
+ attentions=outputs.attentions,
40
+ )
proverka.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BertTokenizerFast, BertConfig
2
+ import torch
3
+ from datasets import load_dataset
4
+ from ebert_model import EBertConfig, EBertModel
5
+ from transformers import BertForMaskedLM
6
+ from safetensors.torch import load_file
7
+ import os
8
+
9
+ model_path = "./ebert"
10
+ tokenizer = BertTokenizerFast.from_pretrained(model_path)
11
+
12
+ config = EBertConfig.from_pretrained(model_path)
13
+ model = BertForMaskedLM(config)
14
+ model.bert = EBertModel(config)
15
+
16
+ weights_path = f"{model_path}/model.safetensors"
17
+ if os.path.exists(weights_path):
18
+ state_dict = load_file(weights_path)
19
+ model.load_state_dict(state_dict, strict=False)
20
+ else:
21
+ raise FileNotFoundError(f"Файл {weights_path} не найден.")
22
+
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ model.to(device)
25
+ model.eval()
26
+ dataset = load_dataset("Expotion/russian-facts-qa", split="train")
27
+
28
+ def predict_masked_text(example):
29
+ text = f"{example['q'].strip()} {example['a'].strip()}"
30
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
31
+ inputs = {k: v.to(device) for k, v in inputs.items()}
32
+
33
+ input_ids = inputs["input_ids"].clone()
34
+ labels = input_ids.clone()
35
+ mask_token_index = 2
36
+ if input_ids.size(1) <= mask_token_index:
37
+ return {
38
+ "original_text": text,
39
+ "masked_text": "Слишком короткий текст",
40
+ "predicted_tokens": [],
41
+ "true_token": ""
42
+ }
43
+ input_ids[0, mask_token_index] = tokenizer.mask_token_id
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
+ predictions = outputs.logits
47
+ predicted_token_ids = torch.topk(predictions[0, mask_token_index], 5).indices.tolist()
48
+ predicted_tokens = [tokenizer.decode([id]).strip() for id in predicted_token_ids]
49
+
50
+ return {
51
+ "original_text": text,
52
+ "masked_text": tokenizer.decode(input_ids[0], skip_special_tokens=True),
53
+ "predicted_tokens": predicted_tokens,
54
+ "true_token": tokenizer.decode([labels[0, mask_token_index]], skip_special_tokens=True)
55
+ }
56
+
57
+ total_params = sum(p.numel() for p in model.parameters())
58
+ print(f"Общее количество параметров: {total_params}")
59
+
60
+ num_examples = 1
61
+ for i in range(num_examples):
62
+ result = predict_masked_text(dataset[i])
63
+ print(f"Оригинал: {result['original_text']}")
64
+ print(f"Замаскированный: {result['masked_text']}")
65
+ print(f"Предсказание: {result['predicted_tokens']}")
66
+ print(f"Истина: {result['true_token']}")