"""Voice model wrapper for HuggingFace models."""
import torch
import torch.nn as nn
import logging
from typing import Optional, Iterator, Dict, Any, Tuple
from pathlib import Path
from transformers import AutoModel, AutoConfig, AutoProcessor

from .policy_wrapper import RLVoiceModel

logger = logging.getLogger(__name__)


class VoiceModelWrapper:
    """
    Wrapper for HuggingFace voice models with RL training support.
    
    Provides a consistent interface for model loading, inference,
    checkpointing, and license verification.
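
    Example (illustrative sketch; the model id is a placeholder, and
    ``input_features`` is assumed to be a batched feature tensor)::

        wrapper = VoiceModelWrapper("org/voice-model", device="cpu")
        wrapper.load_model()
        hidden = wrapper.generate(input_features)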
    """
    
    # List of known commercial-use licenses
    COMMERCIAL_LICENSES = [
        "apache-2.0",
        "mit",
        "bsd",
        "bsd-3-clause",
        "cc-by-4.0",
        "cc-by-sa-4.0",
        "openrail",
    ]
    
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        cache_dir: Optional[str] = None,
        enable_rl: bool = True,
        action_dim: int = 256
    ):
        """
        Initialize the voice model wrapper.

        Args:
            model_name: HuggingFace model identifier
            device: Device to load model on ('cuda', 'cpu', 'mps')
            cache_dir: Optional cache directory for model files
            enable_rl: Whether to add RL policy/value heads
            action_dim: Dimensionality of action space for RL
        """
        self.model_name = model_name
        self.device = device
        self.cache_dir = cache_dir
        self.enable_rl = enable_rl
        self.action_dim = action_dim
        self.model = None
        self.rl_model = None
        self.processor = None
        self.config = None

        logger.info(f"Initialized VoiceModelWrapper for {model_name} on {device} (RL: {enable_rl})")
    
    def load_model(self) -> None:
        """
        Load the voice model from HuggingFace.
        
        Performs license verification and architecture compatibility checks.
        
        Raises:
            ValueError: If model has incompatible license or architecture
            RuntimeError: If model loading fails
        """
        try:
            logger.info(f"Loading model: {self.model_name}")
            
            # Load configuration first
            self.config = AutoConfig.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir
            )
            
            # Verify license
            self._verify_license()
            
            # Verify architecture compatibility
            self._verify_architecture()
            
            # Load model
            self.model = AutoModel.from_pretrained(
                self.model_name,
                cache_dir=self.cache_dir
            )
            self.model.to(self.device)
            self.model.train()  # Set to training mode for RL

            # Wrap with RL policy/value heads if enabled
            if self.enable_rl:
                hidden_size = getattr(self.config, 'hidden_size', 768)
                self.rl_model = RLVoiceModel(
                    base_model=self.model,
                    hidden_size=hidden_size,
                    action_dim=self.action_dim
                )
                self.rl_model.to(self.device)
                logger.info(f"Added RL policy/value heads (action_dim={self.action_dim})")

            # Load processor if available
            try:
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    cache_dir=self.cache_dir
                )
            except Exception as e:
                logger.warning(f"Could not load processor: {e}")
                self.processor = None

            logger.info(f"Successfully loaded model: {self.model_name}")
            logger.info(f"Model parameters: {self.count_parameters():,}")
            
        except Exception as e:
            error_msg = f"Failed to load model {self.model_name}: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e
    
    def _verify_license(self) -> None:
        """
        Verify that the model has a commercial-use license.
        
        Raises:
            ValueError: If license is not suitable for commercial use
        """
        # Try to get license from the config (license metadata usually lives in
        # the model card rather than the config, so this may be absent)
        license_info = getattr(self.config, 'license', None)
        
        if license_info is None:
            logger.warning(
                f"No license information found for {self.model_name}. "
                "Please verify license manually."
            )
            return
        
        license_lower = license_info.lower()
        
        # Check if license is in approved list
        is_commercial = any(
            approved in license_lower 
            for approved in self.COMMERCIAL_LICENSES
        )
        
        if not is_commercial:
            raise ValueError(
                f"Model {self.model_name} has license '{license_info}' "
                f"which may not be suitable for commercial use. "
                f"Approved licenses: {', '.join(self.COMMERCIAL_LICENSES)}"
            )
        
        logger.info(f"License verified: {license_info}")
    
    def _verify_architecture(self) -> None:
        """
        Verify that the model architecture is compatible with RL training.
        
        Checks for required attributes and methods.
        
        Raises:
            ValueError: If architecture is incompatible
        """
        # Check that the configuration exposes the attributes the RL wrapper relies on
        required_attrs = ['hidden_size']
        
        for attr in required_attrs:
            if not hasattr(self.config, attr):
                logger.warning(f"Model config may be missing attribute: {attr}")
        
        # Check model type
        model_type = getattr(self.config, 'model_type', 'unknown')
        logger.info(f"Model type: {model_type}")
        
        # Verify model can be put in training mode
        if self.model is not None and not hasattr(self.model, 'train'):
            raise ValueError("Model does not support training mode")
        
        logger.info("Architecture compatibility verified")
    
    def generate(
        self,
        input_features: torch.Tensor,
        training: bool = False,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate output from the model.

        Args:
            input_features: Input tensor
            training: If True, compute with gradients (for RL training)
            **kwargs: Additional generation parameters

        Returns:
            Generated output tensor

        Raises:
            RuntimeError: If model is not loaded
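
        Example (sketch; assumes a loaded wrapper and a hypothetical
        ``features`` tensor already on the target device)::

            hidden = wrapper.generate(features)                 # inference, no gradients
            hidden = wrapper.generate(features, training=True)  # keep gradients for RL updates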
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if training:
            # During training, keep gradients for backprop
            outputs = self.model(input_features, **kwargs)
        else:
            # During inference, no gradients needed
            with torch.no_grad():
                outputs = self.model(input_features, **kwargs)

        # Handle different output types
        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        elif isinstance(outputs, torch.Tensor):
            return outputs
        else:
            return outputs[0]
    
    def get_logits(self, input_features: torch.Tensor) -> torch.Tensor:
        """
        Get model logits for input features.
        
        Args:
            input_features: Input tensor
        
        Returns:
            Logits tensor
        
        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        
        outputs = self.model(input_features)
        
        if hasattr(outputs, 'logits'):
            return outputs.logits
        elif hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        else:
            return outputs[0]
    
    def forward(self, input_features: torch.Tensor, **kwargs) -> Any:
        """
        Forward pass through the model.

        Args:
            input_features: Input tensor
            **kwargs: Additional forward parameters

        Returns:
            Model outputs (RL-compatible if RL enabled)

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Use RL model if available (returns log_probs, values)
        if self.rl_model is not None:
            return self.rl_model(input_features, **kwargs)
        else:
            return self.model(input_features, **kwargs)

    def sample_action(
        self,
        input_features: torch.Tensor,
        deterministic: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample action from the policy (RL training).

        Args:
            input_features: Input audio features
            deterministic: If True, take most likely action

        Returns:
            Tuple of (actions, log_probs, values)

        Raises:
            RuntimeError: If RL model is not enabled
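
        Example (sketch; assumes the wrapper was loaded with enable_rl=True and
        ``features`` is a hypothetical batched feature tensor)::

            actions, log_probs, values = wrapper.sample_action(features)
            greedy_actions, _, _ = wrapper.sample_action(features, deterministic=True)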
        """
        if self.rl_model is None:
            raise RuntimeError("RL model not enabled. Set enable_rl=True when initializing.")

        return self.rl_model.sample_action(input_features, deterministic)

    def evaluate_actions(
        self,
        input_features: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate actions (for PPO training).

        Args:
            input_features: Input audio features
            actions: Actions to evaluate

        Returns:
            Tuple of (log_probs, values, entropy)

        Raises:
            RuntimeError: If RL model is not enabled
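
        Example (sketch of a PPO-style ratio computation; ``features``, ``actions``
        and ``old_log_probs`` come from a previously collected rollout)::

            log_probs, values, entropy = wrapper.evaluate_actions(features, actions)
            ratio = torch.exp(log_probs - old_log_probs)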
        """
        if self.rl_model is None:
            raise RuntimeError("RL model not enabled. Set enable_rl=True when initializing.")

        return self.rl_model.evaluate_actions(input_features, actions)
    
    def save_checkpoint(self, path: str, metadata: Optional[Dict] = None) -> None:
        """
        Save model checkpoint.

        Args:
            path: Path to save checkpoint
            metadata: Optional metadata to save with checkpoint

        Raises:
            RuntimeError: If model is not loaded
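
        Example (sketch; the path is a placeholder)::

            wrapper.save_checkpoint("checkpoints/step_100.pt", metadata={"step": 100})
            meta = wrapper.load_checkpoint("checkpoints/step_100.pt")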
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        checkpoint_path = Path(path)
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'model_name': self.model_name,
            'config': self.config.to_dict() if self.config else None,
            'enable_rl': self.enable_rl,
            'action_dim': self.action_dim,
        }

        # Save RL model state if present
        if self.rl_model is not None:
            checkpoint['rl_model_state_dict'] = self.rl_model.state_dict()

        if metadata:
            checkpoint['metadata'] = metadata

        torch.save(checkpoint, checkpoint_path)
        logger.info(f"Checkpoint saved to {checkpoint_path}")
    
    def load_checkpoint(self, path: str) -> Dict:
        """
        Load model checkpoint.

        Args:
            path: Path to checkpoint file

        Returns:
            Checkpoint metadata

        Raises:
            RuntimeError: If model is not loaded
            FileNotFoundError: If checkpoint file doesn't exist
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        checkpoint_path = Path(path)
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Load RL model state if present
        if 'rl_model_state_dict' in checkpoint and self.rl_model is not None:
            self.rl_model.load_state_dict(checkpoint['rl_model_state_dict'])
            logger.info("Loaded RL model state")

        logger.info(f"Checkpoint loaded from {checkpoint_path}")

        return checkpoint.get('metadata', {})
    
    def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
        """
        Get an iterator over trainable parameters (including the RL heads when enabled).
        
        Returns:
            Iterator over trainable parameters
        
        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        
        # Include the RL policy/value head parameters when the RL wrapper is active,
        # mirroring the behaviour of count_parameters()
        model_to_iterate = self.rl_model if self.rl_model is not None else self.model
        return (p for p in model_to_iterate.parameters() if p.requires_grad)
    
    def count_parameters(self, trainable_only: bool = False) -> int:
        """
        Count model parameters.

        Args:
            trainable_only: If True, count only trainable parameters

        Returns:
            Number of parameters
        """
        if self.model is None:
            return 0

        # Count RL model params if available, otherwise base model
        model_to_count = self.rl_model if self.rl_model is not None else self.model

        if trainable_only:
            return sum(p.numel() for p in model_to_count.parameters() if p.requires_grad)
        else:
            return sum(p.numel() for p in model_to_count.parameters())
    
    def set_training_mode(self, mode: bool = True) -> None:
        """
        Set model training mode.

        Args:
            mode: If True, set to training mode; otherwise evaluation mode

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if mode:
            self.model.train()
            if self.rl_model is not None:
                self.rl_model.train()
        else:
            self.model.eval()
            if self.rl_model is not None:
                self.rl_model.eval()
    
    def to(self, device: str) -> None:
        """
        Move model to specified device.

        Args:
            device: Target device

        Raises:
            RuntimeError: If model is not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        self.device = device
        self.model.to(device)
        if self.rl_model is not None:
            self.rl_model.to(device)
        logger.info(f"Model moved to {device}")

    def get_rl_model(self) -> Optional[nn.Module]:
        """
        Get the RL-wrapped model.

        Returns:
            RLVoiceModel if RL is enabled, None otherwise
        """
        return self.rl_model