from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
from difflib import SequenceMatcher

class EndpointHandler:
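    """Custom handler for a Hugging Face Inference Endpoint.

    Wraps grammarly/coedit-large (a T5 model fine-tuned for text editing)
    to fix grammar in batches of sentences and to report character-level
    diffs between each input and its corrected output.
    """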
    def __init__(self, path=""):
        # Load model and tokenizer; fall back to the public CoEdIT checkpoint
        model_name = path if path else "grammarly/coedit-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        """Run grammar correction on a batch of sentences.

        Despite the name, this uses CoEdIT's "Fix the grammar:" task prefix,
        so the model edits for grammaticality rather than freely paraphrasing.
        """
        prefix = "Fix the grammar: "
        sentences_with_prefix = [prefix + s for s in sentences]

        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                # Beam search requires num_beams >= num_return_sequences
                num_beams=max(5, num_return_sequences),
                # generate() silently ignores temperature unless sampling is
                # enabled, so turn sampling on when a non-default value is set
                do_sample=temperature != 1.0,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                early_stopping=True
            )
        
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            grouped = [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
            return grouped
        else:
            return decoded
            
    def compute_changes(self, original, enhanced):
        """Describe the edits between the two strings as character-level diffs."""
        changes = []
        matcher = SequenceMatcher(None, original, enhanced)  # char-level, not token-level

        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = original[i1:i2]
                new_phrase = enhanced[j1:j2]
                whitespace_only = original_phrase.isspace() or new_phrase.isspace()
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": i1,
                    "char_end": i2,
                    "token_start": None,  # offsets are character-based, not token-based
                    "token_end": None,
                    "explanation": f"{tag} change",
                    "error_type": "whitespace" if whitespace_only else "",
                    "tip": "Avoid extra spaces between words." if whitespace_only else ""
                })
        return changes
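
    # For illustration: compute_changes("I has a dog.", "I have a dog.")
    # returns a single "replace" change with original_phrase "s",
    # new_phrase "ve", char_start 4 and char_end 5.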


    def __call__(self, inputs):
        # Main entry point for the Hugging Face Inference Endpoint.

        # Accept a bare string, a list of strings, or the common
        # {"inputs": ..., "parameters": {...}} payload format.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            # A single string under "inputs" is wrapped in a list
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }

        # Handle optional parameters
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }

        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []
            
            if num_return_sequences > 1:
                # Logic for multiple return sequences
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Logic for single return sequence
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })
            
            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }
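

# Minimal local smoke test (a sketch: on a real Inference Endpoint the
# toolkit imports EndpointHandler from handler.py and calls it with the
# parsed request body, so this block never runs there).
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": ["She no went to the market yesterday."],
        "parameters": {"num_return_sequences": 1, "temperature": 1.0},
    }
    print(handler(payload))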