dejanseo committed
Commit 26536bf · verified · 1 Parent(s): e2386c2

Upload 8 files

Files changed (8)
  1. app.py +406 -0
  2. config.json +22 -0
  3. generator_config.json +22 -0
  4. pytorch_model.bin +3 -0
  5. spm.model +3 -0
  6. tokenizer_config.json +4 -0
  7. train.py +698 -0
  8. training_config.json +29 -0
app.py ADDED
@@ -0,0 +1,406 @@
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import DebertaV2Model, DebertaV2TokenizerFast, DebertaV2Config, AutoTokenizer
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import json
8
+ import logging
9
+ from dataclasses import dataclass
10
+ from typing import Optional, Dict, List, Tuple
11
+ from tqdm import tqdm
12
+ from skimage.filters import threshold_otsu
13
+
14
+ # ----------------------------------
15
+ # Logging
16
+ # ----------------------------------
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # ----------------------------------
21
+ # Config / Model
22
+ # ----------------------------------
23
+
24
+ @dataclass
25
+ class TrainingConfig:
26
+ """Training configuration for link token classification"""
27
+ model_name: str = "microsoft/deberta-v3-large"
28
+ num_labels: int = 2 # 0: not link, 1: link token
29
+
30
+ # Inference windowing
31
+ max_length: int = 512
32
+ doc_stride: int = 128 # match _prep.py for consistent windowing
33
+
34
+ # Train-only placeholders
35
+ train_file: str = ""
36
+ val_file: str = ""
37
+ batch_size: int = 1
38
+ gradient_accumulation_steps: int = 1
39
+ num_epochs: int = 1
40
+ learning_rate: float = 1e-5
41
+ warmup_ratio: float = 0.1
42
+ weight_decay: float = 0.01
43
+ max_grad_norm: float = 1.0
44
+ label_smoothing: float = 0.0
45
+
46
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
47
+ num_workers: int = 0
48
+ bf16: bool = False
49
+ seed: int = 42
50
+
51
+ logging_steps: int = 1
52
+ eval_steps: int = 100
53
+ save_steps: int = 100
54
+ output_dir: str = "./deberta_link_output" # model is loaded from here
55
+
56
+ wandb_project: str = ""
57
+ wandb_name: str = ""
58
+
59
+ patience: int = 2
60
+ min_delta: float = 0.0001
61
+
62
+
63
+ class DeBERTaForTokenClassification(nn.Module):
64
+ """DeBERTa model for token classification"""
65
+
66
+ def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
67
+ super().__init__()
68
+ self.config = DebertaV2Config.from_pretrained(model_name)
69
+ self.deberta = DebertaV2Model.from_pretrained(model_name)
70
+ self.dropout = nn.Dropout(dropout_rate)
71
+ self.classifier = nn.Linear(self.config.hidden_size, num_labels)
72
+ nn.init.xavier_uniform_(self.classifier.weight)
73
+ nn.init.zeros_(self.classifier.bias)
74
+
75
+ def forward(
76
+ self,
77
+ input_ids: torch.Tensor,
78
+ attention_mask: torch.Tensor,
79
+ labels: Optional[torch.Tensor] = None
80
+ ) -> Dict[str, torch.Tensor]:
81
+ outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
82
+ sequence_output = self.dropout(outputs.last_hidden_state)
83
+ logits = self.classifier(sequence_output)
84
+ return {'loss': None, 'logits': logits}
85
+
86
+ # ----------------------------------
87
+ # Load model/tokenizer (robust)
88
+ # ----------------------------------
89
+
90
+ @st.cache_resource
91
+ def load_model():
92
+ """Loads pre-trained model and tokenizer. Handles raw state_dict and wrapped checkpoints."""
93
+ config = TrainingConfig()
94
+ final_dir = Path(config.output_dir) / "final_model"
95
+ model_path = final_dir / "pytorch_model.bin"
96
+
97
+ if not model_path.exists():
98
+ st.error(f"Model checkpoint not found at {model_path}.")
99
+ st.stop()
100
+
101
+ logger.info(f"Loading model from {model_path}...")
102
+ model = DeBERTaForTokenClassification(config.model_name, config.num_labels)
103
+
104
+ # Load checkpoint robustly
105
+ try:
106
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
107
+ except TypeError:
108
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
109
+
110
+ # Determine state_dict
111
+ state_dict = None
112
+ if isinstance(checkpoint, dict):
113
+ # Case A: raw state_dict (keys -> tensors)
114
+ if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
115
+ state_dict = checkpoint
116
+ logger.info("Detected raw state_dict checkpoint.")
117
+ # Case B: wrapped dicts
118
+ elif 'model_state_dict' in checkpoint and isinstance(checkpoint['model_state_dict'], dict):
119
+ state_dict = checkpoint['model_state_dict']
120
+ logger.info("Detected 'model_state_dict' in checkpoint.")
121
+ elif 'state_dict' in checkpoint and isinstance(checkpoint['state_dict'], dict):
122
+ state_dict = checkpoint['state_dict']
123
+ logger.info("Detected 'state_dict' in checkpoint.")
124
+ else:
125
+ raise KeyError(f"Unrecognized checkpoint format keys: {list(checkpoint.keys())}")
126
+ else:
127
+ raise TypeError(f"Unexpected checkpoint type: {type(checkpoint)}")
128
+
129
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
130
+ if missing:
131
+ logger.warning(f"Missing keys: {missing}")
132
+ if unexpected:
133
+ logger.warning(f"Unexpected keys: {unexpected}")
134
+
135
+ model.to(config.device)
136
+ model.eval()
137
+
138
+ logger.info(f"Loading tokenizer {config.model_name}...")
139
+ tokenizer = DebertaV2TokenizerFast.from_pretrained(config.model_name)
140
+ logger.info("Tokenizer loaded.")
141
+
142
+ return model, tokenizer, config.device, config.max_length, config.doc_stride
143
+
144
+ model, tokenizer, device, MAX_LENGTH, DOC_STRIDE = load_model()
145
+
146
+ # ----------------------------------
147
+ # Inference helpers
148
+ # ----------------------------------
149
+
150
+ def windowize_inference(
151
+ plain_text: str,
152
+ tokenizer: AutoTokenizer,
153
+ max_length: int,
154
+ doc_stride: int
155
+ ) -> List[Dict]:
156
+ """Slice long text into overlapping windows for inference."""
157
+ specials = tokenizer.num_special_tokens_to_add(pair=False)
158
+ cap = max_length - specials
159
+ if cap <= 0:
160
+ raise ValueError(f"max_length too small; specials={specials}")
161
+
162
+ full_encoding = tokenizer(
163
+ plain_text,
164
+ add_special_tokens=False,
165
+ return_offsets_mapping=True,
166
+ return_attention_mask=False,
167
+ return_token_type_ids=False,
168
+ truncation=False,
169
+ )
170
+ input_ids_no_special = full_encoding["input_ids"]
171
+ offsets_no_special = full_encoding["offset_mapping"]
172
+
173
+ temp_encoding_for_word_ids = tokenizer(
174
+ plain_text, return_offsets_mapping=True, truncation=False, padding=False
175
+ )
176
+ full_word_ids = temp_encoding_for_word_ids.word_ids(batch_index=0)
177
+
178
+ windows_data = []
179
+ step = max(cap - doc_stride, 1)
180
+ start_token_idx = 0
181
+ total_tokens_no_special = len(input_ids_no_special)
182
+
183
+ while start_token_idx < total_tokens_no_special:
184
+ end_token_idx = min(start_token_idx + cap, total_tokens_no_special)
185
+
186
+ ids_slice_no_special = input_ids_no_special[start_token_idx:end_token_idx]
187
+ offsets_slice_no_special = offsets_no_special[start_token_idx:end_token_idx]
188
+ word_ids_slice = full_word_ids[start_token_idx:end_token_idx]
189
+
190
+ input_ids_with_special = tokenizer.build_inputs_with_special_tokens(ids_slice_no_special)
191
+ attention_mask_with_special = [1] * len(input_ids_with_special)
192
+
193
+ padding_length = max_length - len(input_ids_with_special)
194
+ if padding_length > 0:
195
+ input_ids_with_special.extend([tokenizer.pad_token_id] * padding_length)
196
+ attention_mask_with_special.extend([0] * padding_length)
197
+
198
+ window_offset_mapping = offsets_slice_no_special[:]
199
+ window_word_ids = word_ids_slice[:]
200
+
201
+ if tokenizer.cls_token_id is not None:
202
+ window_offset_mapping.insert(0, (0, 0))
203
+ window_word_ids.insert(0, None)
204
+ if tokenizer.sep_token_id is not None and len(window_offset_mapping) < max_length:
205
+ window_offset_mapping.append((0, 0))
206
+ window_word_ids.append(None)
207
+
208
+ while len(window_offset_mapping) < max_length:
209
+ window_offset_mapping.append((0, 0))
210
+ window_word_ids.append(None)
211
+
212
+ windows_data.append({
213
+ "input_ids": torch.tensor(input_ids_with_special, dtype=torch.long),
214
+ "attention_mask": torch.tensor(attention_mask_with_special, dtype=torch.long),
215
+ "word_ids": window_word_ids,
216
+ "offset_mapping": window_offset_mapping,
217
+ })
218
+
219
+ if end_token_idx == total_tokens_no_special:
220
+ break
221
+ start_token_idx += step
222
+
223
+ return windows_data
224
+
225
+
226
+ def classify_text(
227
+ text: str,
228
+ otsu_mode: str,
229
+ prediction_threshold_override: Optional[float] = None
230
+ ) -> Tuple[str, Optional[str], Optional[float]]:
231
+ """Classify link tokens with windowing. Returns (html, warning, threshold%)."""
232
+ if not text.strip():
233
+ return "", None, None
234
+
235
+ windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
236
+ if not windows:
237
+ return "", "Could not generate any windows for processing.", None
238
+
239
+ char_link_probabilities = np.zeros(len(text), dtype=np.float32)
240
+ char_covered = np.zeros(len(text), dtype=bool)
241
+ all_content_token_probs = []
242
+
243
+ with torch.no_grad():
244
+ for window in tqdm(windows, desc="Processing windows"):
245
+ inputs = {
246
+ 'input_ids': window['input_ids'].unsqueeze(0).to(device),
247
+ 'attention_mask': window['attention_mask'].unsqueeze(0).to(device)
248
+ }
249
+ outputs = model(**inputs)
250
+ logits = outputs['logits'].squeeze(0)
251
+ probabilities = torch.softmax(logits, dim=-1)
252
+ link_probs_for_window_tokens = probabilities[:, 1].cpu().numpy()
253
+
254
+ for i, (offset_start, offset_end) in enumerate(window['offset_mapping']):
255
+ if window['word_ids'][i] is not None and offset_start < offset_end:
256
+ char_link_probabilities[offset_start:offset_end] = np.maximum(
257
+ char_link_probabilities[offset_start:offset_end],
258
+ link_probs_for_window_tokens[i]
259
+ )
260
+ char_covered[offset_start:offset_end] = True
261
+ all_content_token_probs.append(link_probs_for_window_tokens[i])
262
+
263
+ # Threshold selection (Otsu or manual)
264
+ determined_threshold_float = None
265
+ determined_threshold_for_display = None # 0-100%
266
+
267
+ if prediction_threshold_override is not None:
268
+ determined_threshold_float = prediction_threshold_override / 100.0
269
+ determined_threshold_for_display = prediction_threshold_override
270
+ else:
271
+ if len(all_content_token_probs) > 1:
272
+ try:
273
+ otsu_base_threshold = threshold_otsu(np.array(all_content_token_probs))
274
+ conservative_delta = 0.1 # stricter
275
+ generous_delta = 0.1 # more lenient
276
+ if otsu_mode == 'conservative':
277
+ determined_threshold_float = otsu_base_threshold + conservative_delta
278
+ elif otsu_mode == 'generous':
279
+ determined_threshold_float = otsu_base_threshold - generous_delta
280
+ else:
281
+ determined_threshold_float = otsu_base_threshold
282
+ determined_threshold_float = max(0.0, min(1.0, determined_threshold_float))
283
+ determined_threshold_for_display = determined_threshold_float * 100
284
+ except ValueError:
285
+ logger.warning("Otsu failed; defaulting to 0.5.")
286
+ determined_threshold_float = 0.5
287
+ determined_threshold_for_display = 50.0
288
+ else:
289
+ logger.warning("Insufficient tokens for Otsu; defaulting to 0.5.")
290
+ determined_threshold_float = 0.5
291
+ determined_threshold_for_display = 50.0
292
+
293
+ final_threshold = determined_threshold_float
294
+
295
+ # Word-level aggregation
296
+ full_text_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False, padding=False)
297
+ full_word_ids = full_text_encoding.word_ids(batch_index=0)
298
+ full_offset_mapping = full_text_encoding['offset_mapping']
299
+
300
+ word_prob_map: Dict[int, List[float]] = {}
301
+ word_char_spans: Dict[int, List[int]] = {}
302
+
303
+ for i, word_id in enumerate(full_word_ids):
304
+ if word_id is not None:
305
+ start_char, end_char = full_offset_mapping[i]
306
+ if start_char < end_char and np.any(char_covered[start_char:end_char]):
307
+ if word_id not in word_prob_map:
308
+ word_prob_map[word_id] = []
309
+ word_char_spans[word_id] = [start_char, end_char]
310
+ else:
311
+ word_char_spans[word_id][0] = min(word_char_spans[word_id][0], start_char)
312
+ word_char_spans[word_id][1] = max(word_char_spans[word_id][1], end_char)
313
+
314
+ token_span_probs = char_link_probabilities[start_char:end_char]
315
+ word_prob_map[word_id].append(np.max(token_span_probs) if token_span_probs.size > 0 else 0.0)
316
+ elif word_id not in word_prob_map:
317
+ word_prob_map[word_id] = [0.0]
318
+ word_char_spans[word_id] = list(full_offset_mapping[i])
319
+
320
+ words_to_highlight_status: Dict[int, bool] = {}
321
+ for word_id, probs in word_prob_map.items():
322
+ max_word_prob = np.max(probs) if probs else 0.0
323
+ words_to_highlight_status[word_id] = (max_word_prob >= final_threshold)
324
+
325
+ # Reconstruct HTML with highlights
326
+ html_output_parts: List[str] = []
327
+ current_char_idx = 0
328
+ sorted_word_ids = sorted(word_char_spans.keys(), key=lambda k: word_char_spans[k][0])
329
+
330
+ for word_id in sorted_word_ids:
331
+ start_char, end_char = word_char_spans[word_id]
332
+ if start_char > current_char_idx:
333
+ html_output_parts.append(text[current_char_idx:start_char])
334
+
335
+ word_text = text[start_char:end_char]
336
+ if words_to_highlight_status.get(word_id, False):
337
+ html_output_parts.append(
338
+ "<span style='background-color: #D4EDDA; color: #155724; padding: 0.1em 0.2em; border-radius: 0.2em;'>"
339
+ + word_text +
340
+ "</span>"
341
+ )
342
+ else:
343
+ html_output_parts.append(word_text)
344
+ current_char_idx = end_char
345
+
346
+ if current_char_idx < len(text):
347
+ html_output_parts.append(text[current_char_idx:])
348
+
349
+ return "".join(html_output_parts), None, determined_threshold_for_display
350
+
351
+ # ----------------------------------
352
+ # Streamlit UI
353
+ # ----------------------------------
354
+
355
+ st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
356
+ st.title("LinkBERT")
357
+
358
+ user_input = st.text_area(
359
+ "Paste your text here:",
360
+ "DEJAN AI is the world's leading AI SEO agency.",
361
+ height=200
362
+ )
363
+
364
+ with st.expander('Settings'):
365
+ auto_threshold_enabled = st.checkbox(
366
+ "Automagic",
367
+ value=True,
368
+ help="Uncheck to set manual threshold value for link prediction."
369
+ )
370
+
371
+ otsu_mode_options = ['Conservative', 'Standard', 'Generous']
372
+ selected_otsu_mode = 'Standard'
373
+ if auto_threshold_enabled:
374
+ selected_otsu_mode = st.radio(
375
+ "Generosity:",
376
+ otsu_mode_options,
377
+ index=1,
378
+ help="Generous suggests more links; conservative suggests fewer."
379
+ )
380
+
381
+ prediction_threshold_manual = 50.0
382
+ if not auto_threshold_enabled:
383
+ prediction_threshold_manual = st.slider(
384
+ "Manual Link Probability Threshold (%)",
385
+ min_value=0,
386
+ max_value=100,
387
+ value=50,
388
+ step=1,
389
+ help="Minimum probability to classify a token as a link when Automagic is off."
390
+ )
391
+
392
+ if st.button("Classify Text"):
393
+ if not user_input.strip():
394
+ st.warning("Please enter some text to classify.")
395
+ else:
396
+ threshold_to_pass = None if auto_threshold_enabled else prediction_threshold_manual
397
+ highlighted_html, warning_message, determined_threshold_for_display = classify_text(
398
+ user_input,
399
+ selected_otsu_mode.lower(),
400
+ threshold_to_pass
401
+ )
402
+ if warning_message:
403
+ st.warning(warning_message)
404
+ if determined_threshold_for_display is not None and auto_threshold_enabled:
405
+ st.info(f"Auto threshold: {determined_threshold_for_display:.1f}% ({selected_otsu_mode})")
406
+ st.markdown(highlighted_html, unsafe_allow_html=True)
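For reference, the window arithmetic in windowize_inference works out as follows (a rough sketch with illustrative numbers; the two special tokens assumed here are the [CLS]/[SEP] pair DeBERTa-v3 adds to a single sequence):

    # illustrative numbers only, not output from a real run
    max_length = 512                 # model window used above
    specials   = 2                   # assumed: [CLS] + [SEP] added per window
    doc_stride = 128                 # overlap between consecutive windows
    cap  = max_length - specials     # 510 content tokens fit in each window
    step = cap - doc_stride          # 382: how far each new window advances
    # a 1,000-token input therefore yields windows starting at tokens 0, 382 and 764;
    # overlapping character spans are merged with np.maximum in classify_text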
config.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "model_type": "deberta-v2",
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 1024,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "max_position_embeddings": 512,
+ "relative_attention": true,
+ "position_buckets": 256,
+ "norm_rel_ebd": "layer_norm",
+ "share_att_key": true,
+ "pos_att_type": "p2c|c2p",
+ "layer_norm_eps": 1e-7,
+ "max_relative_positions": -1,
+ "position_biased_input": false,
+ "num_attention_heads": 16,
+ "num_hidden_layers": 24,
+ "type_vocab_size": 0,
+ "vocab_size": 128100
+ }
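As a quick sanity check, the config above can be read back with the same class app.py uses for the backbone (a sketch, assuming config.json sits in the working directory):

    from transformers import DebertaV2Config

    cfg = DebertaV2Config.from_pretrained(".")        # picks up ./config.json
    print(cfg.hidden_size, cfg.num_hidden_layers)     # 1024, 24 -> deberta-v3-large geometry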
generator_config.json ADDED
@@ -0,0 +1,22 @@
+ {
+ "model_type": "deberta-v2",
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 1024,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "max_position_embeddings": 512,
+ "relative_attention": true,
+ "position_buckets": 256,
+ "norm_rel_ebd": "layer_norm",
+ "share_att_key": true,
+ "pos_att_type": "p2c|c2p",
+ "layer_norm_eps": 1e-7,
+ "max_relative_positions": -1,
+ "position_biased_input": false,
+ "num_attention_heads": 16,
+ "num_hidden_layers": 12,
+ "type_vocab_size": 0,
+ "vocab_size": 128100
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf39ce70d265128366245b987d930b293445bafcc323a3f1d7cc6f8594139c14
+ size 1736224579
spm.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
+ size 2464616
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+ "do_lower_case": false,
+ "vocab_type": "spm"
+ }
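Together with spm.model, this is enough to build the tokenizer locally rather than downloading it from the Hub as app.py currently does (a sketch, assuming both files live in the repository root and sentencepiece is installed):

    from transformers import DebertaV2TokenizerFast

    tok = DebertaV2TokenizerFast.from_pretrained(".")   # reads spm.model + tokenizer_config.json
    print(tok.num_special_tokens_to_add(pair=False))    # 2: the 'specials' subtracted during windowing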
train.py ADDED
@@ -0,0 +1,698 @@
1
+ import json
2
+ import shutil
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from transformers import (
7
+ DebertaV2Model,
8
+ DebertaV2TokenizerFast,
9
+ DebertaV2Config,
10
+ get_linear_schedule_with_warmup,
11
+ set_seed
12
+ )
13
+ from torch.cuda.amp import autocast
14
+ from tqdm import tqdm
15
+ import numpy as np
16
+ from pathlib import Path
17
+ import logging
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Dict, List, Tuple
20
+ import wandb
21
+ from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
22
+ import functools # Import functools for partial
23
+ import re
24
+
25
+ # Setup logging
26
+ logging.basicConfig(
27
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
28
+ datefmt='%m/%d/%Y %H:%M:%S',
29
+ level=logging.INFO
30
+ )
31
+ logger = logging.getLogger(__name__)
32
+
33
+ @dataclass
34
+ class TrainingConfig:
35
+ """Training configuration for link token classification"""
36
+ # Model
37
+ model_name: str = "microsoft/deberta-v3-large"
38
+ num_labels: int = 2 # 0: not link, 1: link token
39
+
40
+ # Data
41
+ train_file: str = "train_windows.jsonl"
42
+ val_file: str = "val_windows.jsonl"
43
+ max_length: int = 512 # This is the crucial fixed length for padding
44
+
45
+ # Training
46
+ batch_size: int = 8
47
+ gradient_accumulation_steps: int = 8
48
+ num_epochs: int = 3
49
+ learning_rate: float = 1e-6
50
+ warmup_ratio: float = 0.1
51
+ weight_decay: float = 0.01
52
+ max_grad_norm: float = 1.0
53
+ label_smoothing: float = 0.0 # Not currently used in CrossEntropyLoss
54
+
55
+ # System
56
+ device: str = "cuda" if torch.cuda.is_available() else "cpu"
57
+ num_workers: int = 0 # Set to 0 for Windows to avoid multiprocessing issues
58
+ seed: int = 42
59
+ bf16: bool = True # Using BF16 for RTX 4090
60
+
61
+ # Logging
62
+ logging_steps: int = 1 # Log every step to wandb
63
+ eval_steps: int = 5000
64
+ save_steps: int = 10000
65
+ output_dir: str = "./deberta_link_output"
66
+
67
+ # WandB
68
+ wandb_project: str = "deberta-link-classification"
69
+ wandb_name: str = "deberta-v3-large-link-tokens"
70
+
71
+ # Early stopping
72
+ patience: int = 2
73
+ min_delta: float = 0.0001
74
+
75
+ # Checkpoint retention (Scope A: count all subdirs except 'final_model')
76
+ max_checkpoints: int = 5
77
+ protect_latest_epoch_step: bool = True # Always keep latest best_model_epoch_* and best_model_step_*
78
+
79
+
80
+ class LinkTokenDataset(Dataset):
81
+ """Dataset for link token classification"""
82
+
83
+ def __init__(self, file_path: str, max_samples: Optional[int] = None):
84
+ self.data = []
85
+
86
+ logger.info(f"Loading data from {file_path}")
87
+ seq_lengths = []
88
+
89
+ with open(file_path, 'r') as f:
90
+ for i, line in enumerate(f):
91
+ if max_samples and i >= max_samples:
92
+ break
93
+ sample = json.loads(line)
94
+
95
+ seq_len = len(sample['input_ids'])
96
+ seq_lengths.append(seq_len)
97
+
98
+ # Convert to tensors
99
+ sample['input_ids'] = torch.tensor(sample['input_ids'], dtype=torch.long)
100
+ sample['attention_mask'] = torch.tensor(sample['attention_mask'], dtype=torch.long)
101
+ sample['labels'] = torch.tensor(sample['labels'], dtype=torch.long)
102
+
103
+ self.data.append(sample)
104
+
105
+ logger.info(f"Loaded {len(self.data)} samples")
106
+ logger.info(f"Sequence lengths - Min: {min(seq_lengths)}, Max: {max(seq_lengths)}, Avg: {np.mean(seq_lengths):.1f}")
107
+
108
+ # Calculate class weights for imbalanced data (for logging info)
109
+ total_labels = []
110
+ for s in self.data:
111
+ # Only count non-padded positions (where labels are not -100)
112
+ valid_labels = s['labels'][s['labels'] != -100]
113
+ total_labels.append(valid_labels)
114
+
115
+ # Ensure total_labels is not empty before concatenating
116
+ if total_labels:
117
+ total_labels = torch.cat(total_labels)
118
+ num_link_tokens = (total_labels == 1).sum().item()
119
+ num_non_link = (total_labels == 0).sum().item()
120
+
121
+ logger.info(f"Label distribution - Non-link: {num_non_link}, Link: {num_link_tokens}")
122
+ if (num_link_tokens + num_non_link) > 0:
123
+ logger.info(f"Link token ratio: {num_link_tokens / (num_link_tokens + num_non_link):.4%}")
124
+ else:
125
+ logger.info("No valid labels found in the dataset.")
126
+
127
+ def __len__(self):
128
+ return len(self.data)
129
+
130
+ def __getitem__(self, idx):
131
+ return self.data[idx]
132
+
133
+
134
+ def collate_fn(batch: List[Dict], max_seq_length: int) -> Dict[str, torch.Tensor]:
135
+ """
136
+ Custom collate function for batching with padding to a fixed max_seq_length.
137
+
138
+ Args:
139
+ batch (List[Dict]): A list of samples from the dataset.
140
+ max_seq_length (int): The maximum sequence length to pad all samples to.
141
+
142
+ Returns:
143
+ Dict[str, torch.Tensor]: A dictionary containing stacked and padded tensors.
144
+ """
145
+
146
+ input_ids = []
147
+ attention_mask = []
148
+ labels = []
149
+
150
+ for x in batch:
151
+ seq_len = len(x['input_ids'])
152
+
153
+ # Truncate if sequence is longer than max_seq_length (shouldn't happen with preprocessed data)
154
+ if seq_len > max_seq_length:
155
+ x['input_ids'] = x['input_ids'][:max_seq_length]
156
+ x['attention_mask'] = x['attention_mask'][:max_seq_length]
157
+ x['labels'] = x['labels'][:max_seq_length]
158
+ seq_len = max_seq_length
159
+
160
+ # Pad sequences to the global max_seq_length
161
+ padding_len = max_seq_length - seq_len
162
+
163
+ # Pad input_ids with 0 (typically the pad token id)
164
+ padded_input = torch.cat([
165
+ x['input_ids'],
166
+ torch.zeros(padding_len, dtype=torch.long)
167
+ ])
168
+
169
+ # Pad attention_mask with 0 (ignore padded tokens)
170
+ padded_mask = torch.cat([
171
+ x['attention_mask'],
172
+ torch.zeros(padding_len, dtype=torch.long)
173
+ ])
174
+
175
+ # Pad labels with -100 (ignored in loss calculation)
176
+ padded_labels = torch.cat([
177
+ x['labels'],
178
+ torch.full((padding_len,), -100, dtype=torch.long)
179
+ ])
180
+
181
+ input_ids.append(padded_input)
182
+ attention_mask.append(padded_mask)
183
+ labels.append(padded_labels)
184
+
185
+ return {
186
+ 'input_ids': torch.stack(input_ids),
187
+ 'attention_mask': torch.stack(attention_mask),
188
+ 'labels': torch.stack(labels)
189
+ }
190
+
191
+
192
+ class DeBERTaForTokenClassification(nn.Module):
193
+ """DeBERTa model for token classification"""
194
+
195
+ def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
196
+ super().__init__()
197
+
198
+ self.config = DebertaV2Config.from_pretrained(model_name)
199
+ self.deberta = DebertaV2Model.from_pretrained(model_name)
200
+
201
+ self.dropout = nn.Dropout(dropout_rate)
202
+ self.classifier = nn.Linear(self.config.hidden_size, num_labels)
203
+
204
+ # Initialize classifier weights
205
+ nn.init.xavier_uniform_(self.classifier.weight)
206
+ nn.init.zeros_(self.classifier.bias)
207
+
208
+ def forward(
209
+ self,
210
+ input_ids: torch.Tensor,
211
+ attention_mask: torch.Tensor,
212
+ labels: Optional[torch.Tensor] = None
213
+ ) -> Dict[str, torch.Tensor]:
214
+
215
+ outputs = self.deberta(
216
+ input_ids=input_ids,
217
+ attention_mask=attention_mask
218
+ )
219
+
220
+ sequence_output = outputs.last_hidden_state
221
+ sequence_output = self.dropout(sequence_output)
222
+ logits = self.classifier(sequence_output)
223
+
224
+ loss = None
225
+ if labels is not None:
226
+ # Calculate class weights for imbalanced dataset
227
+ # Link tokens are ~3.88% of data, so weight them ~25x more
228
+ # Ensure weight tensor is on the correct device
229
+ weight = torch.tensor([1.0, 25.0]).to(logits.device)
230
+
231
+ loss_fct = nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
232
+ # Reshape logits to (batch_size * sequence_length, num_labels)
233
+ # Reshape labels to (batch_size * sequence_length)
234
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
235
+
236
+ return {
237
+ 'loss': loss,
238
+ 'logits': logits
239
+ }
240
+
241
+
242
+ def compute_metrics(predictions: np.ndarray, labels: np.ndarray, mask: np.ndarray) -> Dict[str, float]:
243
+ """Compute metrics for token classification"""
244
+ # Flatten and remove padding
245
+ # Only consider positions where attention_mask is 1 AND labels are not -100
246
+ # The -100 in labels already implies an ignored position, so we can primarily filter by that.
247
+
248
+ # Flatten all predictions, labels, and masks
249
+ predictions_flat = predictions.flatten()
250
+ labels_flat = labels.flatten()
251
+ mask_flat = mask.flatten()
252
+
253
+ # Create a combined filter for valid tokens (not padding, not -100 label)
254
+ valid_indices = (labels_flat != -100) & (mask_flat == 1)
255
+
256
+ preds_filtered = predictions_flat[valid_indices]
257
+ labels_filtered = labels_flat[valid_indices]
258
+
259
+ # Handle cases where no valid tokens are present
260
+ if len(labels_filtered) == 0:
261
+ return {
262
+ 'accuracy': 0.0,
263
+ 'precision': 0.0,
264
+ 'recall': 0.0,
265
+ 'f1': 0.0,
266
+ 'f1_non_link': 0.0,
267
+ 'f1_link': 0.0,
268
+ 'precision_link': 0.0,
269
+ 'recall_link': 0.0,
270
+ 'num_valid_tokens': 0
271
+ }
272
+
273
+ # Calculate metrics
274
+ accuracy = accuracy_score(labels_filtered, preds_filtered)
275
+
276
+ precision, recall, f1, support = precision_recall_fscore_support(
277
+ labels_filtered, preds_filtered, average='binary', pos_label=1, zero_division=0
278
+ )
279
+
280
+ # Per-class metrics
281
+ unique_labels_in_data = np.unique(labels_filtered)
282
+
283
+ precision_per_class = [0.0, 0.0]
284
+ recall_per_class = [0.0, 0.0]
285
+ f1_per_class = [0.0, 0.0]
286
+
287
+ # Class 0 (non-link)
288
+ if 0 in unique_labels_in_data:
289
+ p0, r0, f0, _ = precision_recall_fscore_support(
290
+ labels_filtered, preds_filtered, labels=[0], average='binary', pos_label=0, zero_division=0
291
+ )
292
+ precision_per_class[0] = p0
293
+ recall_per_class[0] = r0
294
+ f1_per_class[0] = f0
295
+
296
+ # Class 1 (link)
297
+ if 1 in unique_labels_in_data:
298
+ p1, r1, f1_1, _ = precision_recall_fscore_support(
299
+ labels_filtered, preds_filtered, labels=[1], average='binary', pos_label=1, zero_division=0
300
+ )
301
+ precision_per_class[1] = p1
302
+ recall_per_class[1] = r1
303
+ f1_per_class[1] = f1_1
304
+
305
+ return {
306
+ 'accuracy': accuracy,
307
+ 'precision': precision,
308
+ 'recall': recall,
309
+ 'f1': f1,
310
+ 'f1_non_link': f1_per_class[0],
311
+ 'f1_link': f1_per_class[1],
312
+ 'precision_link': precision_per_class[1],
313
+ 'recall_link': recall_per_class[1],
314
+ 'num_valid_tokens': len(labels_filtered)
315
+ }
316
+
317
+
318
+ class Trainer:
319
+ """Trainer class for DeBERTa token classification"""
320
+
321
+ def __init__(self, config: TrainingConfig):
322
+ self.config = config
323
+ set_seed(config.seed)
324
+
325
+ # Initialize wandb
326
+ wandb.init(
327
+ project=config.wandb_project,
328
+ name=config.wandb_name,
329
+ config=vars(config)
330
+ )
331
+
332
+ # Create output directory
333
+ Path(config.output_dir).mkdir(parents=True, exist_ok=True)
334
+
335
+ # Load datasets
336
+ self.train_dataset = LinkTokenDataset(config.train_file)
337
+ self.val_dataset = LinkTokenDataset(config.val_file)
338
+
339
+ # Create dataloaders
340
+ # Use functools.partial to pass the fixed max_length to collate_fn
341
+ self.train_loader = DataLoader(
342
+ self.train_dataset,
343
+ batch_size=config.batch_size,
344
+ shuffle=False,
345
+ num_workers=config.num_workers,
346
+ collate_fn=functools.partial(collate_fn, max_seq_length=config.max_length),
347
+ pin_memory=True
348
+ )
349
+
350
+ self.val_loader = DataLoader(
351
+ self.val_dataset,
352
+ batch_size=config.batch_size * 2, # Often larger batch size for validation
353
+ shuffle=False,
354
+ num_workers=config.num_workers,
355
+ collate_fn=functools.partial(collate_fn, max_seq_length=config.max_length),
356
+ pin_memory=True
357
+ )
358
+
359
+ # Initialize model
360
+ self.model = DeBERTaForTokenClassification(
361
+ config.model_name,
362
+ config.num_labels
363
+ ).to(config.device)
364
+
365
+ # Count parameters
366
+ total_params = sum(p.numel() for p in self.model.parameters())
367
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
368
+ logger.info(f"Total parameters: {total_params:,}")
369
+ logger.info(f"Trainable parameters: {trainable_params:,}")
370
+
371
+ # Initialize optimizer
372
+ no_decay = ['bias', 'LayerNorm.weight']
373
+ optimizer_grouped_parameters = [
374
+ {
375
+ 'params': [p for n, p in self.model.named_parameters()
376
+ if not any(nd in n for nd in no_decay)],
377
+ 'weight_decay': config.weight_decay
378
+ },
379
+ {
380
+ 'params': [p for n, p in self.model.named_parameters()
381
+ if any(nd in n for nd in no_decay)],
382
+ 'weight_decay': 0.0
383
+ }
384
+ ]
385
+
386
+ self.optimizer = torch.optim.AdamW(
387
+ optimizer_grouped_parameters,
388
+ lr=config.learning_rate,
389
+ eps=1e-6
390
+ )
391
+
392
+ # Initialize scheduler
393
+ total_steps = len(self.train_loader) * config.num_epochs // config.gradient_accumulation_steps
394
+ warmup_steps = int(total_steps * config.warmup_ratio)
395
+
396
+ self.scheduler = get_linear_schedule_with_warmup(
397
+ self.optimizer,
398
+ num_warmup_steps=warmup_steps,
399
+ num_training_steps=total_steps
400
+ )
401
+
402
+ # Tracking variables
403
+ self.global_step = 0
404
+ self.best_val_loss = float('inf')
405
+ self.patience_counter = 0
406
+
407
+ def train_epoch(self, epoch: int) -> float:
408
+ """Train for one epoch"""
409
+ self.model.train()
410
+ total_loss = 0
411
+ progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch}")
412
+
413
+ # Flag to indicate if early stopping was triggered mid-epoch
414
+ early_stop_triggered = False
415
+
416
+ for step, batch in enumerate(progress_bar):
417
+ # Move batch to device
418
+ batch = {k: v.to(self.config.device) for k, v in batch.items()}
419
+
420
+ # Forward pass with BF16 mixed precision
421
+ if self.config.bf16:
422
+ with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
423
+ outputs = self.model(**batch)
424
+ loss = outputs['loss'] / self.config.gradient_accumulation_steps
425
+ else:
426
+ outputs = self.model(**batch)
427
+ loss = outputs['loss'] / self.config.gradient_accumulation_steps
428
+
429
+ # Check if loss is NaN or inf, and skip if it is
430
+ if torch.isnan(loss) or torch.isinf(loss):
431
+ logger.warning(f"NaN or Inf loss encountered at step {self.global_step}. Skipping backward pass.")
432
+ self.optimizer.zero_grad() # Clear gradients for current batch
433
+ continue # Skip this step
434
+
435
+ loss.backward()
436
+ total_loss += loss.item()
437
+
438
+ # Gradient accumulation
439
+ if (step + 1) % self.config.gradient_accumulation_steps == 0:
440
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
441
+ self.optimizer.step()
442
+ self.scheduler.step()
443
+ self.optimizer.zero_grad()
444
+ self.global_step += 1
445
+
446
+ # Logging - every step to wandb
447
+ if self.global_step % self.config.logging_steps == 0:
448
+ current_loss = loss.item() * self.config.gradient_accumulation_steps
449
+ wandb.log({
450
+ 'train/loss': current_loss,
451
+ 'train/learning_rate': self.scheduler.get_last_lr()[0],
452
+ 'train/global_step': self.global_step,
453
+ 'train/epoch': epoch
454
+ })
455
+ progress_bar.set_postfix({'loss': f'{current_loss:.4f}'})
456
+
457
+ # Evaluation
458
+ if self.global_step % self.config.eval_steps == 0:
459
+ eval_metrics = self.evaluate()
460
+ logger.info(f"Step {self.global_step} - Eval metrics: {eval_metrics}")
461
+
462
+ # Early stopping check based on validation loss
463
+ current_val_loss = eval_metrics['loss']
464
+ if current_val_loss < self.best_val_loss - self.config.min_delta:
465
+ self.best_val_loss = current_val_loss
466
+ self.patience_counter = 0
467
+ self.save_model(f"best_model_step_{self.global_step}")
468
+ logger.info(f"New best validation loss: {self.best_val_loss:.4f}")
469
+ else:
470
+ self.patience_counter += 1
471
+ logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")
472
+ if self.patience_counter >= self.config.patience:
473
+ logger.info("Early stopping triggered mid-epoch!")
474
+ early_stop_triggered = True
475
+ break # Break from the inner loop (current epoch)
476
+
477
+ if early_stop_triggered:
478
+ break # Break from the outer loop (current epoch)
479
+
480
+ return total_loss / len(self.train_loader) if len(self.train_loader) > 0 else 0.0 # Return 0 if loader is empty
481
+
482
+ def evaluate(self) -> Dict[str, float]:
483
+ """Evaluate on validation set"""
484
+ self.model.eval()
485
+
486
+ all_predictions = []
487
+ all_labels = []
488
+ all_masks = []
489
+ total_loss = 0
490
+ num_batches = 0
491
+
492
+ with torch.no_grad():
493
+ for batch in tqdm(self.val_loader, desc="Evaluating"):
494
+ batch = {k: v.to(self.config.device) for k, v in batch.items()}
495
+
496
+ # Use BF16 for evaluation too
497
+ if self.config.bf16:
498
+ with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
499
+ outputs = self.model(**batch)
500
+ else:
501
+ outputs = self.model(**batch)
502
+
503
+ if outputs['loss'] is not None:
504
+ total_loss += outputs['loss'].item()
505
+ num_batches += 1
506
+
507
+ predictions = torch.argmax(outputs['logits'], dim=-1)
508
+
509
+ all_predictions.append(predictions.cpu().numpy())
510
+ all_labels.append(batch['labels'].cpu().numpy())
511
+ all_masks.append(batch['attention_mask'].cpu().numpy())
512
+
513
+ all_predictions = np.concatenate(all_predictions, axis=0)
514
+ all_labels = np.concatenate(all_labels, axis=0)
515
+ all_masks = np.concatenate(all_masks, axis=0)
516
+
517
+ # Compute metrics
518
+ metrics = compute_metrics(all_predictions, all_labels, all_masks)
519
+ metrics['loss'] = total_loss / num_batches if num_batches > 0 else 0.0
520
+
521
+ # Log to wandb
522
+ wandb.log({f'eval/{k}': v for k, v in metrics.items()}, step=self.global_step)
523
+
524
+ self.model.train() # Set model back to train mode after evaluation
525
+ return metrics
526
+
527
+ def _enforce_checkpoint_limit(self):
528
+ """
529
+ Enforce checkpoint retention:
530
+ - Count all subdirectories in output_dir except 'final_model'
531
+ - Keep at most config.max_checkpoints
532
+ - Delete oldest by modification time
533
+ - Always protect:
534
+ * 'final_model'
535
+ * latest 'best_model_epoch_*'
536
+ * latest 'best_model_step_*'
537
+ """
538
+ output_dir = Path(self.config.output_dir)
539
+ if not output_dir.exists():
540
+ return
541
+
542
+ # List all subdirectories
543
+ subdirs = [p for p in output_dir.iterdir() if p.is_dir()]
544
+ if not subdirs:
545
+ return
546
+
547
+ # Identify protected directories
548
+ protected = set()
549
+
550
+ # Always protect 'final_model' if present
551
+ final_dir = output_dir / "final_model"
552
+ if final_dir.exists() and final_dir.is_dir():
553
+ protected.add(final_dir.resolve())
554
+
555
+ if self.config.protect_latest_epoch_step:
556
+ # Latest best_model_epoch_*
557
+ epoch_dirs = [d for d in subdirs if re.match(r"best_model_epoch_\d+$", d.name)]
558
+ if epoch_dirs:
559
+ latest_epoch = max(epoch_dirs, key=lambda d: d.stat().st_mtime)
560
+ protected.add(latest_epoch.resolve())
561
+
562
+ # Latest best_model_step_*
563
+ step_dirs = [d for d in subdirs if re.match(r"best_model_step_\d+$", d.name)]
564
+ if step_dirs:
565
+ latest_step = max(step_dirs, key=lambda d: d.stat().st_mtime)
566
+ protected.add(latest_step.resolve())
567
+
568
+ # Candidates counted toward limit: all except 'final_model'
569
+ counted = [d for d in subdirs if d.resolve() != final_dir.resolve()]
570
+
571
+ # Nothing to do if within limit
572
+ if len(counted) <= self.config.max_checkpoints:
573
+ return
574
+
575
+ # Sort by mtime (oldest first)
576
+ counted_sorted = sorted(counted, key=lambda d: d.stat().st_mtime)
577
+
578
+ # Iteratively delete oldest non-protected until within limit
579
+ to_delete = []
580
+ current = len(counted)
581
+ for d in counted_sorted:
582
+ if current <= self.config.max_checkpoints:
583
+ break
584
+ if d.resolve() in protected:
585
+ continue
586
+ to_delete.append(d)
587
+ current -= 1
588
+
589
+ # If still above limit because everything old was protected,
590
+ # continue deleting oldest even if protected EXCEPT final_model,
591
+ # but try to avoid removing the most recent protected items by re-check.
592
+ if current > self.config.max_checkpoints:
593
+ # Recompute deletable set excluding final_model only
594
+ extras = [d for d in counted_sorted if d.resolve() != final_dir.resolve() and d not in to_delete]
595
+ for d in extras:
596
+ if current <= self.config.max_checkpoints:
597
+ break
598
+ # Do not delete the most recent protected epoch/step if possible
599
+ if d.resolve() in protected:
600
+ continue
601
+ to_delete.append(d)
602
+ current -= 1
603
+
604
+ # Execute deletions
605
+ for d in to_delete:
606
+ try:
607
+ shutil.rmtree(d)
608
+ logger.info(f"Deleted old checkpoint: {d}")
609
+ except Exception as e:
610
+ logger.warning(f"Failed to delete {d}: {e}")
611
+
612
+ def save_model(self, name: str):
613
+ """Save model checkpoint"""
614
+ save_path = Path(self.config.output_dir) / name
615
+ save_path.mkdir(parents=True, exist_ok=True)
616
+
617
+ # Only save model state dict to keep file size manageable
618
+ torch.save(self.model.state_dict(), save_path / 'pytorch_model.bin')
619
+
620
+ # Save config separately
621
+ with open(save_path / 'training_config.json', 'w') as f:
622
+ json.dump(vars(self.config), f, indent=4)
623
+
624
+ logger.info(f"Model saved to {save_path}")
625
+
626
+ # Enforce retention after each save
627
+ self._enforce_checkpoint_limit()
628
+
629
+ def train(self):
630
+ """Main training loop"""
631
+ logger.info("Starting training...")
632
+ logger.info(f"Training samples: {len(self.train_dataset)}")
633
+ logger.info(f"Validation samples: {len(self.val_dataset)}")
634
+
635
+ # Calculate total optimization steps accurately
636
+ total_optimization_steps = (len(self.train_loader) + self.config.gradient_accumulation_steps - 1) // self.config.gradient_accumulation_steps * self.config.num_epochs
637
+ logger.info(f"Total optimization steps: {total_optimization_steps}")
638
+ logger.info(f"Early stopping: monitoring validation loss with patience={self.config.patience}")
639
+
640
+ for epoch in range(self.config.num_epochs):
641
+ logger.info(f"\n{'='*50}")
642
+ logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")
643
+
644
+ # Train
645
+ avg_train_loss = self.train_epoch(epoch + 1)
646
+ logger.info(f"Average training loss: {avg_train_loss:.4f}")
647
+
648
+ # Check if early stopping was already triggered mid-epoch from train_epoch
649
+ if self.patience_counter >= self.config.patience:
650
+ logger.info("Training stopped due to early stopping during epoch.")
651
+ break
652
+
653
+ # Evaluate at end of epoch if not already stopped
654
+ eval_metrics = self.evaluate()
655
+ logger.info(f"Epoch {epoch + 1} - Eval metrics:")
656
+ for key, value in eval_metrics.items():
657
+ logger.info(f" {key}: {value:.4f}")
658
+
659
+ # Check for early stopping at epoch level
660
+ current_val_loss = eval_metrics['loss']
661
+ if current_val_loss < self.best_val_loss - self.config.min_delta:
662
+ self.best_val_loss = current_val_loss
663
+ self.patience_counter = 0
664
+ self.save_model(f"best_model_epoch_{epoch + 1}")
665
+ logger.info(f"New best validation loss at epoch end: {self.best_val_loss:.4f}")
666
+ else:
667
+ self.patience_counter += 1
668
+ logger.info(f"No improvement in validation loss. Patience: {self.patience_counter}/{self.config.patience}")
669
+
670
+ # Check for early stopping
671
+ if self.patience_counter >= self.config.patience:
672
+ logger.info("Training stopped due to early stopping")
673
+ break
674
+
675
+ # Save final model
676
+ self.save_model("final_model")
677
+
678
+ logger.info("Training completed!")
679
+ logger.info(f"Best validation loss: {self.best_val_loss:.4f}")
680
+ wandb.finish()
681
+
682
+
683
+ def main():
684
+ """Main function"""
685
+ config = TrainingConfig()
686
+
687
+ # Optimized for RTX 4090 with BF16
688
+ # You can override config here based on your VRAM usage:
689
+ # config.batch_size = 32 # RTX 4090 can handle larger batches with 24GB VRAM
690
+ # config.gradient_accumulation_steps = 1 # May not need accumulation
691
+ # config.learning_rate = 1e-5 # Sometimes better for fine-tuning
692
+
693
+ trainer = Trainer(config)
694
+ trainer.train()
695
+
696
+
697
+ if __name__ == "__main__":
698
+ main()
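A minimal way to launch the trainer above, including the kind of overrides hinted at in main() (a sketch; it assumes train_windows.jsonl and val_windows.jsonl already exist from the preprocessing step and that wandb is configured, since Trainer.__init__ calls wandb.init):

    from train import TrainingConfig, Trainer

    config = TrainingConfig()
    config.gradient_accumulation_steps = 8   # effective batch = 8 * 8 = 64 sequences
    config.learning_rate = 1e-6              # stock value from the dataclass
    Trainer(config).train()                  # checkpoints land under ./deberta_link_output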
training_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+ "model_name": "microsoft/deberta-v3-large",
+ "num_labels": 2,
+ "train_file": "train_windows.jsonl",
+ "val_file": "val_windows.jsonl",
+ "max_length": 512,
+ "batch_size": 8,
+ "gradient_accumulation_steps": 8,
+ "num_epochs": 3,
+ "learning_rate": 1e-06,
+ "warmup_ratio": 0.1,
+ "weight_decay": 0.01,
+ "max_grad_norm": 1.0,
+ "label_smoothing": 0.0,
+ "device": "cuda",
+ "num_workers": 0,
+ "seed": 42,
+ "bf16": true,
+ "logging_steps": 1,
+ "eval_steps": 5000,
+ "save_steps": 10000,
+ "output_dir": "./deberta_link_output",
+ "wandb_project": "deberta-link-classification",
+ "wandb_name": "deberta-v3-large-link-tokens",
+ "patience": 2,
+ "min_delta": 0.0001,
+ "max_checkpoints": 5,
+ "protect_latest_epoch_step": true
+ }
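This file is simply the json.dump(vars(config)) written by Trainer.save_model, so it round-trips back into the dataclass (a sketch, assuming the keys still match the TrainingConfig fields in train.py):

    import json
    from train import TrainingConfig

    with open("training_config.json") as f:
        cfg = TrainingConfig(**json.load(f))
    print(cfg.learning_rate, cfg.output_dir)   # 1e-06 ./deberta_link_output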