annnli committed
Commit 8c38f91 · verified · 1 Parent(s): da73d19

Upload folder using huggingface_hub

Files changed (1):
  1. modeling_roberta_cl.py (+46 -389)
modeling_roberta_cl.py CHANGED
@@ -2,36 +2,10 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  import torch.distributed as dist
- from torch import Tensor

  import transformers
- from transformers import RobertaTokenizer
- from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaClassificationHead, RobertaPreTrainedModel, RobertaModel, RobertaLMHead
- from transformers.models.qwen2.modeling_qwen2 import Qwen2PreTrainedModel, Qwen2Model
- from transformers.activations import gelu
- from transformers.file_utils import (
-     add_code_sample_docstrings,
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     replace_return_docstrings,
- )
- from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
-
- class MLPLayer(nn.Module):
-     """
-     Head for getting sentence representations over RoBERTa/BERT's CLS representation.
-     """
-
-     def __init__(self, config):
-         super().__init__()
-         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-         self.activation = nn.Tanh()
-
-     def forward(self, features, **kwargs):
-         x = self.dense(features)
-         x = self.activation(x)
-
-         return x
+ from transformers.models.roberta.modeling_roberta import RobertaForSequenceClassification, RobertaClassificationHead
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

  class ResidualBlock(nn.Module):
      def __init__(self, dim):
@@ -98,70 +72,6 @@ class RobertaClassificationHeadForEmbedding(RobertaClassificationHead):
          # x = self.dropout(x)
          # x = self.out_proj(x)
          return x
-
-
- class QueryHead(nn.Module):
-     def __init__(self, hidden_size):
-         super(QueryHead, self).__init__()
-         # Learnable query vector
-         self.query = nn.Parameter(torch.randn(hidden_size))
-
-     def forward(self, hidden_states, attention_mask=None):
-         """
-         Args:
-             hidden_states: Tensor of shape (batch_size, seq_length, hidden_size)
-             attention_mask: Tensor of shape (batch_size, seq_length) with 1 for real tokens and 0 for padding tokens.
-         Returns:
-             sequence_embedding: Tensor of shape (batch_size, hidden_size)
-         """
-         # Compute raw attention scores
-         attention_scores = torch.matmul(hidden_states, self.query)  # (batch_size, seq_length)
-
-         # Apply attention mask (set padding positions to large negative value before softmax)
-         if attention_mask is not None:
-             attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e4)
-
-         # Normalize attention scores
-         attention_weights = F.softmax(attention_scores, dim=1)  # (batch_size, seq_length)
-
-         # Aggregate hidden states
-         sequence_embedding = torch.matmul(attention_weights.unsqueeze(1), hidden_states).squeeze(1)  # (batch_size, hidden_size)
-
-         return sequence_embedding
-
-
- class AttentionPooling(nn.Module):
-     def __init__(self, hidden_dim):
-         super().__init__()
-         self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)    # Key matrix W_K
-         self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)  # Value matrix W_V
-         self.query = nn.Parameter(torch.randn(hidden_dim))               # Learnable query vector
-
-     def forward(self, x, attention_mask=None):
-         """
-         Args:
-             x: Tensor of shape (B, L, H), the last hidden layer output.
-             attention_mask: Tensor of shape (B, L) with 1 for real tokens and 0 for padding tokens.
-         Returns:
-             pooled_output: Tensor of shape (B, H), the pooled sequence embedding.
-         """
-         K = self.key_proj(x)    # (B, L, H)
-         V = self.value_proj(x)  # (B, L, H)
-
-         # Compute attention scores
-         attn_scores = torch.matmul(K, self.query) / (K.shape[-1] ** 0.5)  # (B, L)
-
-         # Apply attention mask (set padding tokens to large negative value)
-         if attention_mask is not None:
-             attn_scores = attn_scores.masked_fill(attention_mask == 0, -1e4)
-
-         attn_weights = F.softmax(attn_scores, dim=1)  # (B, L)
-
-         # Weighted sum of values
-         pooled_output = torch.matmul(attn_weights.unsqueeze(1), V).squeeze(1)  # (B, H)
-         # pooled_output = torch.sum(attn_weights.unsqueeze(-1) * V, dim=1)  # (B, H)
-
-         return pooled_output

  def cl_init(cls, config):
      """
@@ -194,8 +104,6 @@ def cl_forward(cls,
      output_attentions=None,
      output_hidden_states=None,
      return_dict=None,
-     mlm_input_ids=None,
-     mlm_labels=None,
      latter_sentiment_spoof_mask=None,
  ):
      return_dict = return_dict if return_dict is not None else cls.config.use_return_dict
@@ -204,97 +112,29 @@ def cl_forward(cls,
      # original + cls.model_args.num_paraphrased + cls.model_args.num_negative
      num_sent = input_ids.size(1)

-     # # input_ids: (bs, num_sent, len)
-     # # random downsample one paraphrased sentence from sentences index in [1, cls.model_args.num_paraphrased-1]
-     # # randomly generate one index from [1, cls.model_args.num_paraphrased-1]
-     # # exclude tensor [:, index, :] from input_ids
-     # paraphrased_idx = torch.randint(1, cls.model_args.num_paraphrased, (batch_size,))
-     # mask = torch.ones_like(input_ids, dtype=torch.bool)
-     # for i in range(batch_size):
-     #     mask[i, paraphrased_idx[i], :] = False
-     # input_ids = input_ids[mask].view(batch_size, num_sent - 1, -1)
-     # attention_mask = attention_mask[mask].view(batch_size, num_sent - 1, -1)
-     # num_paraphrased = cls.model_args.num_paraphrased - 1
-     # num_sent -= 1
-     # if token_type_ids is not None:
-     #     token_type_ids = token_type_ids[mask].view(batch_size, num_sent - 1, -1)
-
-     mlm_outputs = None
      # Flatten input for encoding
      input_ids = input_ids.view((-1, input_ids.size(-1)))  # (bs * num_sent, len)
      attention_mask = attention_mask.view((-1, attention_mask.size(-1)))  # (bs * num_sent, len)
      if token_type_ids is not None:
          token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1)))  # (bs * num_sent, len)

-     if 'roberta' in cls.model_args.model_name_or_path:
-         # Get raw embeddings
-         outputs = cls.roberta(
-             input_ids,
-             attention_mask=attention_mask,
-             token_type_ids=token_type_ids,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
-             return_dict=True,
-         )
-
-         # MLM auxiliary objective
-         if mlm_input_ids is not None:
-             mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))
-             mlm_outputs = cls.roberta(
-                 mlm_input_ids,
-                 attention_mask=attention_mask,
-                 token_type_ids=token_type_ids,
-                 position_ids=position_ids,
-                 head_mask=head_mask,
-                 inputs_embeds=inputs_embeds,
-                 output_attentions=output_attentions,
-                 output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
-                 return_dict=True,
-             )
-
-         # Pooling
-         sequence_output = outputs[0]  # (bs*num_sent, seq_len, hidden)
-         pooler_output = cls.classifier(sequence_output)  # (bs*num_sent, hidden)
-         pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)
-
-     elif 'qwen2' in cls.model_args.model_name_or_path.lower():
-         def last_token_pool(last_hidden_states: Tensor,
-                             attention_mask: Tensor) -> Tensor:
-             left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
-             if left_padding:
-                 return last_hidden_states[:, -1]
-             else:
-                 sequence_lengths = attention_mask.sum(dim=1) - 1
-                 batch_size = last_hidden_states.shape[0]
-                 return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
-
-         outputs = cls.model(
-             input_ids,
-             attention_mask=attention_mask,
-             token_type_ids=token_type_ids,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,
-             return_dict=True,
-         )
+     # Get raw embeddings
+     outputs = cls.roberta(
+         input_ids,
+         attention_mask=attention_mask,
+         token_type_ids=token_type_ids,
+         position_ids=position_ids,
+         head_mask=head_mask,
+         inputs_embeds=inputs_embeds,
+         output_attentions=output_attentions,
+         output_hidden_states=False,
+         return_dict=True,
+     )

-         if cls.model_args.pooler_type in ['query', 'attention']:
-             pooler_output = cls.pool(outputs.last_hidden_state, attention_mask)
-         elif cls.model_args.pooler_type == 'last':
-             pooler_output = last_token_pool(outputs.last_hidden_state, attention_mask)
-         else:
-             raise NotImplementedError
-         # normalize embeddings
-         pooler_output = F.normalize(pooler_output, p=2, dim=1)
-
-         pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden_states)
-     else:
-         raise NotImplementedError
+     # Pooling
+     sequence_output = outputs[0]  # (bs*num_sent, seq_len, hidden)
+     pooler_output = cls.classifier(sequence_output)  # (bs*num_sent, hidden)
+     pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

      # Mapping
      pooler_output = cls.map(pooler_output)  # (bs, num_sent, hidden_states)
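For reference, the encoder branches deleted in this hunk covered the Qwen2 path with its pooling options (the learned-query heads QueryHead / AttentionPooling removed earlier in this diff, or last-token pooling, followed by L2 normalization) and, on the RoBERTa side, an optional auxiliary MLM forward pass. With last hidden states $X \in \mathbb{R}^{L \times d}$, projections $K = XW_K$, $V = XW_V$ and a learnable query $q$, the removed attention pooling computed

$$\alpha = \operatorname{softmax}\!\Big(\tfrac{Kq}{\sqrt{d}}\Big) \in \mathbb{R}^{L}, \qquad \operatorname{pool}(X) = \alpha^{\top}V \in \mathbb{R}^{d},$$

with padded positions pushed to $-10^{4}$ before the softmax (QueryHead is the same construction without the projections and scaling). The retained path instead pools by passing outputs[0] through the RoBERTa classification head (cls.classifier) and then cls.map.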
@@ -310,11 +150,6 @@ def cl_forward(cls,
      # Gather all embeddings if using distributed training
      if dist.is_initialized() and cls.training:
          raise NotImplementedError
-
-     # straight-through estimate sign function
-     def sign_ste(x):
-         x_nogradient = x.detach()
-         return x + x.sign() - x_nogradient

      # get sign value before calculating similarity
      original = torch.tanh(original * 1000)
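Within this diff the removed sign_ste helper (a straight-through estimator for sign) appears only in commented-out references, and the retained line original = torch.tanh(original * 1000) already serves as a smooth surrogate for the sign:

$$\tanh(1000\,x) \approx \operatorname{sign}(x) \quad \text{for } |x| \gtrsim 10^{-2},$$

so the embedding is effectively binarized before similarities are computed while the operation stays differentiable.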
@@ -325,61 +160,21 @@ def cl_forward(cls,
      for cname, n in zip(spoofing_cnames, negative_list):
          negative_dict[cname] = n

-     # z1 = sign_ste(z1)
-     # z2_list = [sign_ste(z2) for z2 in z2_list]
-     # z3_list = [sign_ste(z3) for z3 in z3_list]
-
-     # Compute contrastive loss
-     if cls.model_args.cl_weight != 0:
-         negative_weight = cls.model_args.hard_negative_weight
-         ori_ori_cos = cls.sim(original.unsqueeze(1), original.unsqueeze(0))  # (bs, bs)
-         ori_ori_cos_removed = remove_diagonal_elements(ori_ori_cos)  # (bs, bs-1)
-         ori_para_cos_list = [cls.sim(original, p).unsqueeze(1) for p in paraphrase_list]  # [(bs, 1)] * num_paraphrased
-         ori_neg_cos_list = [cls.sim(original, n).unsqueeze(1) for n in negative_list]  # [(bs,1)] * num_negative
-         ori_neg_cos_dict = {}
-         for cname, n in zip(spoofing_cnames, ori_neg_cos_list):
-             ori_neg_cos_dict[cname] = n
-
-         loss_cl = 0
-         for i in range(batch_size):
-             ori = ori_ori_cos_removed[i].sum()
-             neg = 0
+     # Calculate triplet loss
+     loss_triplet = 0
+     for i in range(batch_size):
+         for j in range(cls.model_args.num_paraphrased):
              for cname in spoofing_cnames:
                  if cname == 'latter_sentiment_spoof_0' and latter_sentiment_spoof_mask[i] == 0:
                      continue
-                 neg += ori_neg_cos_dict[cname][i]
-             for j in range(cls.model_args.num_paraphrased):
-                 pos = ori_para_cos_list[j][i]
-                 denominator = ori + pos + negative_weight * neg
-                 fraction = pos / (ori + pos + negative_weight * neg)
-                 loss_cl -= torch.log(fraction)
-         loss_cl /= (batch_size * cls.model_args.num_paraphrased)
+                 ori = original[i]
+                 pos = paraphrase_list[j][i]
+                 neg = negative_dict[cname][i]
+                 loss_triplet += F.relu(cls.sim(ori, neg) * cls.model_args.temp - cls.sim(ori, pos) * cls.model_args.temp + cls.model_args.margin)
+     loss_triplet /= (batch_size * cls.model_args.num_paraphrased * len(spoofing_cnames))

-     # Calculate triplet loss
-     if cls.model_args.tl_weight != 0:
-         loss_triplet = 0
-         for i in range(batch_size):
-             for j in range(cls.model_args.num_paraphrased):
-                 for cname in spoofing_cnames:
-                     if cname == 'latter_sentiment_spoof_0' and latter_sentiment_spoof_mask[i] == 0:
-                         continue
-                     ori = original[i]
-                     pos = paraphrase_list[j][i]
-                     neg = negative_dict[cname][i]
-                     loss_triplet += F.relu(cls.sim(ori, neg) * cls.model_args.temp - cls.sim(ori, pos) * cls.model_args.temp + cls.model_args.margin)
-         loss_triplet /= (batch_size * cls.model_args.num_paraphrased * len(spoofing_cnames))
-
-     # Calculate loss for MLM
-     if mlm_outputs is not None and mlm_labels is not None:
-         raise NotImplementedError
-         # mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))
-         # prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)
-         # masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))
-         # loss_cl = loss_cl + cls.model_args.mlm_weight * masked_lm_loss
-
      # Calculate loss for uniform perturbation and unbiased token preference
      def sign_loss(x):
-         # smooth_sign = sign_ste(x)
          row = torch.abs(torch.mean(torch.mean(x, dim=0)))
          col = torch.abs(torch.mean(torch.mean(x, dim=1)))
          return (row + col)/2
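With the contrastive and MLM branches gone, the only pairwise ranking term left in cl_forward is the margin (triplet) loss assembled above; written out, it is

$$\mathcal{L}_{\mathrm{tl}} = \frac{1}{B\,P\,C}\sum_{i=1}^{B}\sum_{j=1}^{P}\sum_{c\in\mathcal{C}} \max\Big(0,\; \tau\,\mathrm{sim}(o_i, n_{i,c}) - \tau\,\mathrm{sim}(o_i, p_{i,j}) + m\Big),$$

where $o_i$ is the original embedding, $p_{i,j}$ the $j$-th paraphrase, $n_{i,c}$ the spoofed negative for variant $c$, $\tau$ is model_args.temp, $m$ is model_args.margin, $B$ the batch size, $P$ = num_paraphrased, and $C$ the number of spoofing variants; terms for latter_sentiment_spoof_0 are skipped when the corresponding mask entry is 0 while the denominator stays $B\,P\,C$, exactly as in the code.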
@@ -390,8 +185,6 @@ def cl_forward(cls,
      loss_3_list = [cls.sim(original, p).unsqueeze(1) for p in paraphrase_list]  # [(bs, 1)] * num_paraphrased
      loss_3_tensor = torch.cat(loss_3_list, dim=1)  # (bs, num_paraphrased)
      loss_3 = loss_3_tensor.mean() * cls.model_args.temp
-     # debug:
-     # loss_3 = loss_3[valid_for_loss3.bool()]

      # calculate loss_sent: similarity between original and sentiment spoofed text
      negative_sample_loss = {}
@@ -409,14 +202,7 @@ def cl_forward(cls,
      ori_ori_cos_removed = remove_diagonal_elements(ori_ori_cos)  # (bs, bs-1)
      loss_5 = ori_ori_cos_removed.mean() * cls.model_args.temp

-     if cls.model_args.cl_weight != 0 and cls.model_args.tl_weight != 0:
-         loss = loss_gr + cls.model_args.cl_weight * loss_cl + cls.model_args.tl_weight * loss_triplet
-     elif cls.model_args.cl_weight != 0 and cls.model_args.tl_weight == 0:
-         loss = loss_gr + cls.model_args.cl_weight * loss_cl
-     elif cls.model_args.cl_weight == 0 and cls.model_args.tl_weight != 0:
-         loss = loss_gr + cls.model_args.tl_weight * loss_triplet
-     else:
-         raise ValueError("Both contrastive loss and triplet loss weights are zero.")
+     loss = loss_gr + loss_triplet

      result = {
          'loss': loss,
@@ -431,10 +217,7 @@ def cl_forward(cls,
          key = f"sim_{cname.replace('_spoof_0', '')}"
          result[key] = l

-     if cls.model_args.cl_weight != 0:
-         result['loss_cl'] = loss_cl
-     if cls.model_args.tl_weight != 0:
-         result['loss_tl'] = loss_triplet
+     result['loss_tl'] = loss_triplet

      if not return_dict:
          raise NotImplementedError
@@ -455,60 +238,23 @@ def sentemb_forward(
      output_attentions=None,
      output_hidden_states=None,
      return_dict=None,
-     lambda_1=1.0,
-     lambda_2=1.0,
  ):

      return_dict = return_dict if return_dict is not None else cls.config.use_return_dict

-     if 'roberta' in cls.model_args.model_name_or_path:
-         outputs = cls.roberta(
-             input_ids,
-             attention_mask=attention_mask,
-             token_type_ids=token_type_ids,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=False,
-             return_dict=True,
-         )
-         sequence_output = outputs[0]
-         pooler_output = cls.classifier(sequence_output)
-     elif 'qwen2' in cls.model_args.model_name_or_path.lower():
-         def last_token_pool(last_hidden_states: Tensor,
-                             attention_mask: Tensor) -> Tensor:
-             left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
-             if left_padding:
-                 return last_hidden_states[:, -1]
-             else:
-                 sequence_lengths = attention_mask.sum(dim=1) - 1
-                 batch_size = last_hidden_states.shape[0]
-                 return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
-
-         outputs = cls.model(
-             input_ids,
-             attention_mask=attention_mask,
-             token_type_ids=token_type_ids,
-             position_ids=position_ids,
-             head_mask=head_mask,
-             inputs_embeds=inputs_embeds,
-             output_attentions=output_attentions,
-             output_hidden_states=True,
-             return_dict=True,
-         )
-
-         if cls.model_args.pooler_type in ['query', 'attention']:
-             pooler_output = cls.pool(outputs.last_hidden_state, attention_mask)
-         elif cls.model_args.pooler_type == 'last':
-             pooler_output = last_token_pool(outputs.last_hidden_state, attention_mask)
-         else:
-             raise NotImplementedError
-         # normalize embeddings
-         pooler_output = F.normalize(pooler_output, p=2, dim=1)
-     else:
-         raise NotImplementedError
-
+     outputs = cls.roberta(
+         input_ids,
+         attention_mask=attention_mask,
+         token_type_ids=token_type_ids,
+         position_ids=position_ids,
+         head_mask=head_mask,
+         inputs_embeds=inputs_embeds,
+         output_attentions=output_attentions,
+         output_hidden_states=False,
+         return_dict=True,
+     )
+     sequence_output = outputs[0]
+     pooler_output = cls.classifier(sequence_output)

      # Mapping
      mapping_output = cls.map(pooler_output)
@@ -530,103 +276,18 @@ class RobertaForCL(RobertaForSequenceClassification):

      def __init__(self, config, *model_args, **model_kargs):
          super().__init__(config)
-         self.model_args = model_kargs["model_args"]
+         self.model_args = model_kargs.get("model_args", None)

          self.classifier = RobertaClassificationHeadForEmbedding(config)

-         if self.model_args.do_mlm:
-             self.lm_head = RobertaLMHead(config)
+         if self.model_args:
+             cl_init(self, config)

          self.map = SemanticModel(input_dim=768)
-         cl_init(self, config)
-
-         if self.model_args.freeze_base:
-             # Freeze RoBERTa encoder parameters
-             for param in self.roberta.parameters():
-                 param.requires_grad = False
-             for param in self.classifier.parameters():
-                 param.requires_grad = False
-
+

          # Initialize weights and apply final processing
          self.post_init()

-     def initialize_mlp_weights(self, pretrained_model_state_dict):
-         """
-         Initialize MLP weights using the pretrained classifier's weights.
-         """
-         self.mlp.dense.weight.data = pretrained_model_state_dict.classifier.dense.weight.data.clone()
-         self.mlp.dense.bias.data = pretrained_model_state_dict.classifier.dense.bias.data.clone()
-
-     def forward(self,
-         input_ids=None,
-         attention_mask=None,
-         token_type_ids=None,
-         position_ids=None,
-         head_mask=None,
-         inputs_embeds=None,
-         labels=None,
-         output_attentions=None,
-         output_hidden_states=None,
-         return_dict=None,
-         sent_emb=False,
-         mlm_input_ids=None,
-         mlm_labels=None,
-         latter_sentiment_spoof_mask=None,
-     ):
-         if sent_emb:
-             return sentemb_forward(self,
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 token_type_ids=token_type_ids,
-                 position_ids=position_ids,
-                 head_mask=head_mask,
-                 inputs_embeds=inputs_embeds,
-                 labels=labels,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-             )
-         else:
-             return cl_forward(self,
-                 input_ids=input_ids,
-                 attention_mask=attention_mask,
-                 token_type_ids=token_type_ids,
-                 position_ids=position_ids,
-                 head_mask=head_mask,
-                 inputs_embeds=inputs_embeds,
-                 labels=labels,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-                 mlm_input_ids=mlm_input_ids,
-                 mlm_labels=mlm_labels,
-                 latter_sentiment_spoof_mask=latter_sentiment_spoof_mask,
-             )
-
- class Qwen2ForCL(Qwen2PreTrainedModel):
-     _keys_to_ignore_on_load_missing = [r"position_ids"]
-
-     def __init__(self, config, *model_args, **model_kargs):
-         super().__init__(config)
-         self.model_args = model_kargs["model_args"]
-         self.model = Qwen2Model(config)
-
-         if self.model_args.pooler_type == 'query':
-             self.pool = QueryHead(config.hidden_size)
-         elif self.model_args.pooler_type == 'attention':
-             self.pool = AttentionPooling(config.hidden_size)
-
-         # if self.model_args.do_mlm:
-         #     self.lm_head = RobertaLMHead(config)
-
-         cl_init(self, config)
-         self.map = SemanticModel(input_dim=1536)
-
-         if self.model_args.freeze_base:
-             # Freeze Qwen parameters
-             for param in self.model.parameters():
-                 param.requires_grad = False
-
      def forward(self,
          input_ids=None,
          attention_mask=None,
@@ -639,8 +300,6 @@ class Qwen2ForCL(Qwen2PreTrainedModel):
          output_hidden_states=None,
          return_dict=None,
          sent_emb=False,
-         mlm_input_ids=None,
-         mlm_labels=None,
          latter_sentiment_spoof_mask=None,
      ):
          if sent_emb:
@@ -668,8 +327,6 @@ class Qwen2ForCL(Qwen2PreTrainedModel):
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
                  return_dict=return_dict,
-                 mlm_input_ids=mlm_input_ids,
-                 mlm_labels=mlm_labels,
                  latter_sentiment_spoof_mask=latter_sentiment_spoof_mask,
              )

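A minimal loading sketch for the simplified class, under stated assumptions: the checkpoint path is a placeholder rather than a real repo id, the rest of modeling_roberta_cl.py (SemanticModel, cl_init, the similarity head) ships alongside the weights, and a stock roberta-base tokenizer is appropriate. None of this is taken from the commit itself.

# Hypothetical usage sketch; path and tokenizer are placeholders.
import torch
from transformers import RobertaTokenizer
from modeling_roberta_cl import RobertaForCL  # the file changed in this commit

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# After this commit model_args is optional (model_kargs.get("model_args", None)),
# so the class can be instantiated for inference without the training-time arguments.
model = RobertaForCL.from_pretrained("/path/to/checkpoint")  # placeholder path
model.eval()

batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
with torch.no_grad():
    # sent_emb=True routes through sentemb_forward: cls.roberta -> cls.classifier -> cls.map,
    # the single remaining encoder path after this change.
    outputs = model(**batch, sent_emb=True)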