iitolstykh committed
Commit ef31f0e (verified)
1 Parent(s): 81db16a

Upload 2 files

Files changed (2):
  1. configuration_gigacheck.py +28 -0
  2. modeling_gigacheck.py +223 -0
configuration_gigacheck.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Dict, Optional, Any
+ from transformers import MistralConfig
+
+
+ class GigaCheckConfig(MistralConfig):
+     def __init__(
+         self,
+         with_detr: bool = False,
+         detr_config: Optional[Dict[str, Any]] = None,
+         freeze_backbone: bool = False,
+         id2label: Optional[Dict[int, str]] = None,
+         num_labels: int = 2,
+         max_length: int = 1024,
+         conf_interval_thresh: float = 0.8,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.with_detr = with_detr
+         self.detr_config = detr_config
+         self.freeze_backbone = freeze_backbone
+         self.id2label = id2label
+         self.num_labels = num_labels
+         self.max_length = max_length
+         self.conf_interval_thresh = conf_interval_thresh
+
+         if self.id2label:
+             # JSON round-trips turn int keys into strings; normalize them back to int.
+             self.id2label = {int(k): v for k, v in self.id2label.items()}
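
For orientation, a minimal usage sketch for this config (not part of the commit; the field values are illustrative, and any detr_config contents beyond the "extractor_dtype" key that modeling_gigacheck.py reads are assumptions):

    from configuration_gigacheck import GigaCheckConfig

    # Illustrative values; a real checkpoint ships its own config.json.
    config = GigaCheckConfig(
        with_detr=True,
        detr_config={"extractor_dtype": "bfloat16"},  # key consumed by from_pretrained below
        id2label={0: "Human", 1: "AI"},               # hypothetical label mapping
        max_length=1024,
        conf_interval_thresh=0.8,
    )
    print(config.id2label)  # keys normalized to int even when loaded from JSON as strings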
modeling_gigacheck.py ADDED
@@ -0,0 +1,223 @@
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import AutoTokenizer
+ from transformers.modeling_outputs import ModelOutput
+
+ from gigacheck.model.mistral_ai_detector import MistralAIDetectorForSequenceClassification
+ from gigacheck.model.src.interval_detector.span_utils import span_cxw_to_xx
+
+ from .configuration_gigacheck import GigaCheckConfig
+
+
+ @dataclass
+ class GigaCheckOutput(ModelOutput):
+     """
+     Output type for the GigaCheck model.
+
+     Args:
+         pred_label_ids (torch.Tensor): [Batch] Indices of the predicted classes (Human/AI/Mixed).
+         classification_head_probs (torch.Tensor): [Batch, Num_Classes] Softmax probabilities.
+         ai_intervals (List[torch.Tensor]): List of length Batch. Each element is a tensor of shape
+             [Num_Intervals, 3] containing (start, end, score) for detected AI-generated spans.
+     """
+     pred_label_ids: Optional[torch.Tensor] = None
+     classification_head_probs: Optional[torch.Tensor] = None
+     ai_intervals: Optional[List[torch.Tensor]] = None
+
+
+ class GigaCheckForDetection(MistralAIDetectorForSequenceClassification):
+     config_class = GigaCheckConfig
+
+     def __init__(self, config: GigaCheckConfig):
+         super().__init__(
+             config,
+             with_detr=config.with_detr,
+             detr_config=config.detr_config,
+             ce_weights=None,
+             freeze_backbone=False,
+             id2label=config.id2label,
+         )
+         self.trained_classification_head = True
+         self._max_len = self.config.max_length
+         self.tokenizer = None
+         self.conf_interval_thresh = config.conf_interval_thresh
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):  # type: ignore
+         """Loads a pretrained GigaCheck model from a local path or the Hugging Face Hub.
+
+         Args:
+             pretrained_model_name_or_path (str): The name or path of the pretrained model.
+             model_args: Additional positional arguments passed to the parent class.
+             kwargs: Additional keyword arguments passed to the parent class.
+
+         Returns:
+             GigaCheckForDetection: The model with loaded weights and an initialized tokenizer.
+         """
+         # Load the model weights
+         model = super().from_pretrained(
+             pretrained_model_name_or_path,
+             *model_args,
+             **kwargs,
+         )
+
+         if model.config.with_detr:
+             extractor_dtype = getattr(torch, model.config.detr_config["extractor_dtype"])
+             print(f"Using dtype={extractor_dtype} for {type(model.model)}")
+             if extractor_dtype == torch.bfloat16:
+                 model.model.to(torch.bfloat16)
+                 model.classification_head.to(torch.bfloat16)
+
+         if model.config.to_dict().get("trained_classification_head", True) is False:
+             # Only the DETR head was trained, so the classification head is skipped at inference.
+             model.trained_classification_head = False
+
+         model.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+
+         # Ensure a pad token exists, falling back to the unk token
+         model.config.pad_token_id = (
+             model.tokenizer.pad_token_id
+             if model.tokenizer.pad_token_id is not None
+             else model.tokenizer.unk_token_id
+         )
+         if model.tokenizer.pad_token_id is None:
+             model.tokenizer.pad_token_id = model.tokenizer.unk_token_id
+
+         model.config.bos_token_id = model.tokenizer.bos_token_id
+         model.config.eos_token_id = model.tokenizer.eos_token_id
+         model.config.unk_token_id = model.tokenizer.unk_token_id
+
+         return model
+
+     def _get_inputs(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
+         """
+         Tokenizes a batch of texts, handling truncation so that each text's exact
+         character length is known for mapping predicted spans back onto the text.
+         """
+         assert self._max_len is not None and self.tokenizer is not None, "Model must be initialized"
+
+         # 1. Tokenize all texts without special tokens or padding first
+         raw_encodings = self.tokenizer(texts, add_special_tokens=False)
+
+         batch_features = []  # list of dicts for tokenizer.pad
+         text_lens = []
+
+         content_max_len = self._max_len - 2  # reserve room for BOS and EOS
+         bos_id = self.tokenizer.bos_token_id
+         eos_id = self.tokenizer.eos_token_id
+
+         for i, tokens in enumerate(raw_encodings.input_ids):
+             if len(tokens) > content_max_len:
+                 tokens = tokens[:content_max_len]
+                 # Decode back to a string to get the exact character length of the truncated part
+                 cur_text = self.tokenizer.decode(tokens, skip_special_tokens=True)
+                 text_len = len(cur_text)
+             else:
+                 # No truncation: use the original text length
+                 text_len = len(texts[i])
+
+             # Construct the final token sequence: [BOS] + tokens + [EOS]
+             final_tokens = [bos_id] + tokens + [eos_id]
+
+             batch_features.append({"input_ids": final_tokens})
+             text_lens.append(text_len)
+
+         # 2. Pad the batch with tokenizer.pad
+         padded_output = self.tokenizer.pad(
+             batch_features,
+             padding=True,
+             return_tensors="pt",
+         )
+
+         input_ids = padded_output["input_ids"].to(self.device)
+         attention_mask = padded_output["attention_mask"].to(self.device)
+
+         return input_ids, attention_mask, text_lens
+
+     @staticmethod
+     def _get_ai_intervals(
+         detr_out: Dict[str, torch.Tensor],
+         text_lens: List[int],
+         conf_interval_thresh: float,
+     ) -> List[torch.Tensor]:
+         """
+         Converts DETR outputs to absolute character intervals in the input texts.
+         """
+         pred_spans = detr_out["pred_spans"]   # (batch_size, #queries, 2)
+         src_logits = detr_out["pred_logits"]  # (batch_size, #queries, #classes=2)
+         assert len(text_lens) == pred_spans.shape[0]
+
+         # Take probabilities for foreground objects only (index 0)
+         pred_probs = torch.softmax(src_logits, dim=-1)[:, :, 0:1]  # [Batch, Queries, 1]
+
+         final_preds_batch = []
+
+         for i, length in enumerate(text_lens):
+             # Convert normalized center-width spans in [0, 1] to absolute (start, end) in [0, length]
+             # pred_spans[i]: [Queries, 2]
+             spans_abs = to_absolute(pred_spans[i], length)
+
+             # Concatenate spans and scores: [Queries, 3] -> (start, end, score)
+             scores = pred_probs[i]
+             preds_i = torch.cat([spans_abs, scores], dim=1)
+
+             # Keep only predictions above the confidence threshold
+             mask = preds_i[:, 2] > conf_interval_thresh
+             filtered_preds = preds_i[mask]
+
+             final_preds_batch.append(filtered_preds)
+
+         return final_preds_batch
+
+     def forward(
+         self,
+         text: Union[str, List[str]],
+         return_dict: Optional[bool] = None,
+         conf_interval_thresh: Optional[float] = None,
+     ) -> Union[Tuple, GigaCheckOutput]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         conf_interval_thresh = (
+             conf_interval_thresh if conf_interval_thresh is not None else self.config.conf_interval_thresh
+         )
+
+         if isinstance(text, str):
+             text = [text]
+
+         input_ids, attention_mask, text_lens = self._get_inputs(text)
+
+         output = super().forward(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             return_dict=True,
+             return_detr_output=self.config.with_detr,
+         )
+
+         pred_label_ids = None
+         classification_head_probs = None
+         ai_intervals = None
+
+         # 1. Classification head processing
+         if not self.config.with_detr:
+             logits = output.logits
+         elif self.trained_classification_head:
+             logits, _ = output.logits
+         else:
+             logits = None
+
+         if logits is not None:
+             # logits: [Batch, NumClasses]
+             probs = logits.to(torch.float32).softmax(dim=-1)
+             pred_label_ids = torch.argmax(probs, dim=-1)  # [Batch]
+             classification_head_probs = probs             # [Batch, NumClasses]
+
+         # 2. Interval detection (DETR) processing
+         if self.config.with_detr:
+             _, detr_out = output.logits
+             ai_intervals = self._get_ai_intervals(detr_out, text_lens, conf_interval_thresh)
+
+         if not return_dict:
+             return (pred_label_ids, classification_head_probs, ai_intervals)
+
+         return GigaCheckOutput(
+             pred_label_ids=pred_label_ids,
+             classification_head_probs=classification_head_probs,
+             ai_intervals=ai_intervals,
+         )
+
+
+ def to_absolute(pred_spans: torch.Tensor, text_len: int) -> torch.Tensor:
+     """Converts normalized (center, width) spans to absolute (start, end) offsets clamped to [0, text_len]."""
+     spans = span_cxw_to_xx(pred_spans) * text_len
+     return torch.clamp(spans, 0, text_len)
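
For orientation, a minimal end-to-end inference sketch (not part of the commit; the checkpoint path is a placeholder, and whether intervals are returned depends on the checkpoint's with_detr flag):

    import torch
    from modeling_gigacheck import GigaCheckForDetection

    # Placeholder: substitute the actual Hub repo id or a local checkpoint directory.
    model = GigaCheckForDetection.from_pretrained("path/to/gigacheck-checkpoint")
    model.eval()

    with torch.no_grad():
        out = model("Some paragraph to check for AI-generated content.")

    print(out.pred_label_ids)             # [Batch] class indices (Human/AI/Mixed)
    print(out.classification_head_probs)  # [Batch, NumClasses] softmax probabilities
    if out.ai_intervals is not None:      # populated only when config.with_detr is True
        for spans in out.ai_intervals:    # per text: [Num_Intervals, 3] of (start, end, score)
            print(spans)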