from typing import List
import torch
import logging

class BaseModel:
    """
    Base class for all models.

    Subclasses are expected to set ``self.tokenizer`` and ``self.model``
    (e.g. a Hugging Face tokenizer and sequence-classification model)
    before calling the prediction methods below.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        # Run on GPU when available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Label order must match the output dimension order of the classifier head.
        self.sentiment_labels = ['NEGATIVE', 'NEUTRAL', 'POSITIVE']

    def predict_sentiment(self, text: str) -> str:
        """
        Predict sentiment for a given text.
        
        Args:
            text: Input text for sentiment analysis
            
        Returns:
            Sentiment label ('NEGATIVE', 'NEUTRAL', or 'POSITIVE')
        """
        try:
            # Tokenize, then move the input tensors to the model's device.
            inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                
            # Get sentiment predictions
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            sentiment_idx = torch.argmax(probabilities, dim=1).item()
            sentiment = self.sentiment_labels[sentiment_idx]
            
            return sentiment
        
        except Exception as e:
            self.logger.error(f"Error during sentiment prediction: {str(e)}")
            return "NEUTRAL"  # Return neutral as default on error
    
    def batch_predict_sentiment(self, texts: List[str]) -> List[str]:
        """
        Predict sentiment for a batch of texts.
        
        Args:
            texts: List of input texts for sentiment analysis
            
        Returns:
            List of sentiment labels
        """
        try:
            # Tokenize the whole batch (padding aligns variable-length texts) and move it to the device.
            inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                
            # Get sentiment predictions
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
            sentiment_indices = torch.argmax(probabilities, dim=1).tolist()
            sentiments = [self.sentiment_labels[idx] for idx in sentiment_indices]
            
            return sentiments
        
        except Exception as e:
            self.logger.error(f"Error during batch sentiment prediction: {str(e)}")
            return ["NEUTRAL"] * len(texts)  # Return neutral for all on error