"""
Fine-tune emotion2vec+ on Brazilian Portuguese (PT-BR) emotion datasets (VERBO + emoUERJ).

This script implements Option A from academic research:
- Fine-tune emotion2vec+ (SOTA base model)
- Train on VERBO (1,167 samples) + emoUERJ (377 samples)
- Use data augmentation to improve generalization
- Expected improvement: +5-10% accuracy on PT-BR data
"""

import torch
import numpy as np
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForSequenceClassification,
    TrainingArguments,
    Trainer
)
from datasets import load_dataset, concatenate_datasets, Audio
import logging
from pathlib import Path
import argparse
from typing import Any, Dict, List, Optional
import librosa
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Emotion label mapping
EMOTION_LABELS = {
    "neutral": 0,
    "happy": 1,
    "sad": 2,
    "angry": 3,
    "fearful": 4,
    "disgusted": 5,
    "surprised": 6
}

LABEL_TO_ID = EMOTION_LABELS
ID_TO_LABEL = {v: k for k, v in EMOTION_LABELS.items()}


class AudioAugmenter:
    """Data augmentation for audio to improve model robustness."""

    @staticmethod
    def time_stretch(audio: np.ndarray, rate: float = 1.0) -> np.ndarray:
        """Time stretching (slower/faster)."""
        return librosa.effects.time_stretch(audio, rate=rate)

    @staticmethod
    def pitch_shift(audio: np.ndarray, sr: int, n_steps: float = 0.0) -> np.ndarray:
        """Pitch shifting."""
        return librosa.effects.pitch_shift(audio, sr=sr, n_steps=n_steps)

    @staticmethod
    def add_noise(audio: np.ndarray, noise_factor: float = 0.005) -> np.ndarray:
        """Add white noise."""
        noise = np.random.randn(len(audio))
        return audio + noise_factor * noise

    @staticmethod
    def augment(audio: np.ndarray, sr: int, augment_type: Optional[str] = None) -> np.ndarray:
        """Apply random augmentation."""
        if augment_type == 'time_stretch':
            rate = np.random.uniform(0.9, 1.1)
            return AudioAugmenter.time_stretch(audio, rate)
        elif augment_type == 'pitch_shift':
            n_steps = np.random.uniform(-2, 2)
            return AudioAugmenter.pitch_shift(audio, sr, n_steps)
        elif augment_type == 'noise':
            return AudioAugmenter.add_noise(audio)
        else:
            return audio
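
# Example (hypothetical variable names): randomly add noise to a 16 kHz clip
# before feature extraction:
#   waveform = AudioAugmenter.augment(waveform, sr=16000, augment_type='noise')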


def load_verbo_dataset():
    """
    Load VERBO dataset (1,167 samples, 7 emotions).

    VERBO is a Brazilian Portuguese emotional speech corpus.
    Paper: "VERBO: A Corpus for Emotion Recognition in Brazilian Portuguese"

    Note: This dataset may need to be manually downloaded and prepared.
    """
    logger.info("Loading VERBO dataset...")

    try:
        # Try loading from HuggingFace if available
        dataset = load_dataset("VERBO/emotion", split="train")
        logger.info(f"โœ… VERBO loaded: {len(dataset)} samples")
        return dataset
    except Exception:
        logger.warning("⚠️  VERBO not available on HuggingFace")
        logger.info("Please download VERBO manually from: http://www02.smt.ufrj.br/~verbo/")
        logger.info("Or contact dataset authors for access")
        return None


def load_emouerj_dataset():
    """
    Load emoUERJ dataset (377 samples, 4 emotions).

    emoUERJ is a Brazilian Portuguese emotional speech dataset.
    Paper: "emoUERJ: A Deep Learning-Based Emotion Classifier for Brazilian Portuguese"

    Note: This dataset may need to be manually downloaded and prepared.
    """
    logger.info("Loading emoUERJ dataset...")

    try:
        # Try loading from HuggingFace if available
        dataset = load_dataset("emoUERJ/emotion", split="train")
        logger.info(f"โœ… emoUERJ loaded: {len(dataset)} samples")
        return dataset
    except Exception:
        logger.warning("⚠️  emoUERJ not available on HuggingFace")
        logger.info("Please download emoUERJ manually or contact dataset authors")
        return None


def normalize_emotion_labels(dataset, emotion_field: str = "emotion"):
    """
    Normalize emotion labels to standard 7-class format.

    Maps dataset-specific labels to: neutral, happy, sad, angry, fearful, disgusted, surprised
    """
    def map_label(example):
        emotion = example[emotion_field].lower()

        # Common mappings
        emotion_map = {
            "neutro": "neutral",
            "neutral": "neutral",
            "alegria": "happy",
            "feliz": "happy",
            "happy": "happy",
            "tristeza": "sad",
            "triste": "sad",
            "sad": "sad",
            "raiva": "angry",
            "angry": "angry",
            "medo": "fearful",
            "fearful": "fearful",
            "nojo": "disgusted",
            "disgusted": "disgusted",
            "surpresa": "surprised",
            "surprised": "surprised"
        }

        # Labels outside emotion_map silently fall back to "neutral"; consider
        # logging or raising here if a corpus may carry unexpected labels.
        normalized = emotion_map.get(emotion, "neutral")
        example["label"] = LABEL_TO_ID[normalized]
        example["emotion_text"] = normalized

        return example

    return dataset.map(map_label)
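
# For example, a VERBO row labeled "alegria" is normalized to "happy" and
# assigned label id 1; "medo" becomes "fearful" (id 4).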


def prepare_dataset(examples, processor, augment: bool = False):
    """Prepare dataset for training."""
    audio_arrays = examples["audio"]

    processed = []
    for audio in audio_arrays:
        array = audio["array"]
        sr = audio["sampling_rate"]

        # Resample to 16kHz if needed
        if sr != 16000:
            array = librosa.resample(array, orig_sr=sr, target_sr=16000)

        # Data augmentation (during training only)
        if augment and np.random.random() < 0.5:
            aug_type = np.random.choice(['time_stretch', 'pitch_shift', 'noise'])
            array = AudioAugmenter.augment(array, 16000, aug_type)

        processed.append(array)

    # Process with Wav2Vec2 processor
    inputs = processor(
        processed,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
        max_length=16000 * 10,  # Max 10 seconds
        truncation=True
    )

    inputs["labels"] = examples["label"]
    return inputs
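
# Note: padding happens twice by design: once per map() batch above, and again
# per training batch in DataCollatorWithPadding below. The collator's padding
# is what the model actually sees at train time.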


@dataclass
class DataCollatorWithPadding:
    """Custom data collator for audio data."""
    processor: Wav2Vec2Processor

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Separate features and labels
        input_values = [{"input_values": feature["input_values"]} for feature in features]
        labels = [feature["labels"] for feature in features]

        # Pad input values
        batch = self.processor.pad(
            input_values,
            padding=True,
            return_tensors="pt"
        )

        batch["labels"] = torch.tensor(labels)
        return batch


def compute_metrics(eval_pred):
    """Compute evaluation metrics."""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    accuracy = (predictions == labels).mean()

    # Per-class accuracy
    per_class_acc = {}
    for label_id, label_name in ID_TO_LABEL.items():
        mask = labels == label_id
        if mask.sum() > 0:
            per_class_acc[label_name] = (predictions[mask] == labels[mask]).mean()

    return {
        "accuracy": accuracy,
        **{f"accuracy_{k}": v for k, v in per_class_acc.items()}
    }
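
# compute_metrics surfaces overall and per-class accuracy in the Trainer's
# eval logs, e.g. {"accuracy": 0.81, "accuracy_happy": 0.76, ...}
# (numbers purely illustrative).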


def main():
    parser = argparse.ArgumentParser(description="Fine-tune emotion2vec on PT-BR datasets")
    parser.add_argument("--base-model", type=str, default="emotion2vec/emotion2vec_plus_large",
                        help="Base model to fine-tune")
    parser.add_argument("--output-dir", type=str, default="models/emotion/emotion2vec_finetuned_ptbr",
                        help="Output directory for fine-tuned model")
    parser.add_argument("--epochs", type=int, default=20,
                        help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=8,
                        help="Training batch size")
    parser.add_argument("--learning-rate", type=float, default=3e-5,
                        help="Learning rate")
    parser.add_argument("--augment", action="store_true",
                        help="Use data augmentation")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device to use (cuda/cpu)")

    args = parser.parse_args()

    logger.info("=" * 60)
    logger.info("Fine-tuning emotion2vec on Portuguese BR datasets")
    logger.info("=" * 60)
    logger.info(f"Base model: {args.base_model}")
    logger.info(f"Device: {args.device}")
    logger.info(f"Epochs: {args.epochs}")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Data augmentation: {args.augment}")

    # Load datasets
    verbo = load_verbo_dataset()
    emouerj = load_emouerj_dataset()

    if verbo is None and emouerj is None:
        logger.error("❌ No datasets available. Please download VERBO and/or emoUERJ manually.")
        logger.info("\nDataset sources:")
        logger.info("- VERBO: http://www02.smt.ufrj.br/~verbo/")
        logger.info("- emoUERJ: Contact authors or check university repository")
        return

    # Combine datasets
    datasets = []
    if verbo is not None:
        verbo = normalize_emotion_labels(verbo)
        datasets.append(verbo)
    if emouerj is not None:
        emouerj = normalize_emotion_labels(emouerj)
        datasets.append(emouerj)

    combined_dataset = concatenate_datasets(datasets) if len(datasets) > 1 else datasets[0]

    # Cast audio column
    combined_dataset = combined_dataset.cast_column("audio", Audio(sampling_rate=16000))

    # Split into train/validation
    split_dataset = combined_dataset.train_test_split(test_size=0.15, seed=42)
    train_dataset = split_dataset["train"]
    val_dataset = split_dataset["test"]

    logger.info(f"\n๐Ÿ“Š Dataset statistics:")
    logger.info(f"  Training samples: {len(train_dataset)}")
    logger.info(f"  Validation samples: {len(val_dataset)}")

    # Load processor and model
    logger.info(f"\n๐Ÿ”„ Loading base model: {args.base_model}...")
    processor = Wav2Vec2Processor.from_pretrained(args.base_model)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        args.base_model,
        num_labels=len(EMOTION_LABELS),
        id2label=ID_TO_LABEL,
        label2id=LABEL_TO_ID
    )

    # Prepare datasets
    logger.info("\n๐Ÿ”„ Preprocessing datasets...")
    train_dataset = train_dataset.map(
        lambda x: prepare_dataset(x, processor, augment=args.augment),
        batched=True,
        remove_columns=train_dataset.column_names
    )
    val_dataset = val_dataset.map(
        lambda x: prepare_dataset(x, processor, augment=False),
        batched=True,
        remove_columns=val_dataset.column_names
    )

    # Training arguments
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        warmup_ratio=0.1,
        logging_steps=10,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        push_to_hub=False,
        save_total_limit=2,
        fp16=args.device == "cuda",
    )

    # Data collator
    data_collator = DataCollatorWithPadding(processor=processor)

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Train
    logger.info("\n๐Ÿš€ Starting fine-tuning...")
    trainer.train()

    # Evaluate
    logger.info("\n๐Ÿ“Š Final evaluation...")
    metrics = trainer.evaluate()
    logger.info(f"Validation accuracy: {metrics['eval_accuracy']:.4f}")

    # Save model
    logger.info(f"\n๐Ÿ’พ Saving fine-tuned model to {output_dir}...")
    trainer.save_model(str(output_dir))
    processor.save_pretrained(str(output_dir))

    logger.info("\nโœ… Fine-tuning complete!")
    logger.info(f"Model saved to: {output_dir}")
    logger.info("\nTo use this model in the ensemble:")
    logger.info(f"  Emotion2VecModel(model_name='{args.output_dir}', ...)")


if __name__ == "__main__":
    main()