File size: 5,342 Bytes
a43192e 9abf359 a43192e 9abf359 a43192e 9abf359 a43192e 9abf359 a43192e 9abf359 a43192e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
from difflib import SequenceMatcher
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for grammar correction.

    Wraps a CoEdIT-style T5 model (default: grammarly/coedit-large) and
    exposes a __call__ entry point that accepts a bare string, a list of
    strings, or a {"inputs": ..., "parameters": {...}} dictionary.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model.

        Args:
            path: Local model directory supplied by the endpoint runtime;
                falls back to the grammarly/coedit-large hub checkpoint.
        """
        model_name = path if path else "grammarly/coedit-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        # Prefer GPU when the endpoint hardware provides one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        """Generate grammar-corrected versions of a batch of sentences.

        Args:
            sentences: List of raw input sentences.
            num_return_sequences: Number of candidates per sentence.
            temperature: Sampling temperature; any value other than 1.0
                switches sampling on so the parameter actually takes effect.

        Returns:
            A flat list of corrected strings when num_return_sequences == 1,
            otherwise one list of candidates per input sentence.
        """
        # CoEdIT models are instruction-tuned; the task is chosen by prefix.
        prefix = "Fix the grammar: "
        sentences_with_prefix = [prefix + s for s in sentences]
        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=512,
            # Beam search requires num_beams >= num_return_sequences;
            # widen the beam instead of raising for larger requests.
            num_beams=max(5, num_return_sequences),
            temperature=temperature,
            # BUGFIX: temperature is ignored (and warned about) unless
            # sampling is enabled; turn it on only for non-default values
            # so the default path keeps deterministic beam search.
            do_sample=temperature != 1.0,
            num_return_sequences=num_return_sequences,
            early_stopping=True
        )
        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if num_return_sequences > 1:
            # generate() returns candidates grouped contiguously per input.
            return [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
        return decoded

    def compute_changes(self, original, enhanced):
        """Diff two strings and describe each edited span.

        Uses a character-level SequenceMatcher, so token_start/token_end
        are always None (kept for a backward-compatible response shape).

        Returns:
            List of dicts with the changed phrases, character offsets, and
            a whitespace hint when the edit only touches spaces.
        """
        changes = []
        matcher = SequenceMatcher(None, original, enhanced)  # char-level, not token-level
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = original[i1:i2]
                new_phrase = enhanced[j1:j2]
                # Flag pure-whitespace edits so the client can show a tip.
                whitespace_only = original_phrase.isspace() or new_phrase.isspace()
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": i1,
                    "char_end": i2,
                    "token_start": None,  # not token-based anymore
                    "token_end": None,
                    "explanation": f"{tag} change",
                    "error_type": "whitespace" if whitespace_only else "",
                    "tip": "Avoid extra spaces between words." if whitespace_only else ""
                })
        return changes

    def __call__(self, inputs):
        """Main entry point for the Hugging Face endpoint.

        Accepts a bare string, a list of strings, or a dict in the
        {"inputs": ..., "parameters": {...}} format; returns a dict with a
        success flag, per-sentence results, and processing counters.
        """
        # BUGFIX: the error message below promises string support, but a
        # bare string previously fell through to the error branch.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            # Common {"inputs": "...", "parameters": {}} wrapper format.
            sentences = inputs.get("inputs", [])
            # A single string payload is normalized to a one-element batch.
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, list of strings, or a dictionary with 'inputs' and 'parameters' keys."
            }
        # Optional generation parameters with safe defaults.
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)
        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }
        try:
            paraphrased = self.paraphrase_batch(sentences, num_return_sequences, temperature)
            results = []
            if num_return_sequences > 1:
                # Multiple candidates per sentence: one result row per candidate.
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Single candidate per sentence.
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })
            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            # Surface the failure to the caller instead of crashing the pod.
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }