from difflib import SequenceMatcher

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        # Load the model and tokenizer, defaulting to the CoEdIT grammar-editing model
        model_name = path if path else "grammarly/coedit-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def paraphrase_batch(self, sentences, num_return_sequences=1, temperature=1.0):
        # Add the text-editing instruction prefix that CoEdIT expects on each sentence
        prefix = "Fix the grammar: "
        sentences_with_prefix = [prefix + s for s in sentences]

        inputs = self.tokenizer(
            sentences_with_prefix,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(self.device)

        # Note: temperature only affects generation when sampling is enabled;
        # with pure beam search it is ignored. num_return_sequences must not
        # exceed num_beams.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                num_beams=5,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                early_stopping=True
            )

        decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

        if num_return_sequences > 1:
            # generate() returns candidates flattened in input order, so slice
            # them back into one list of candidates per input sentence
            grouped = [
                decoded[i * num_return_sequences:(i + 1) * num_return_sequences]
                for i in range(len(sentences))
            ]
            return grouped
        return decoded

    def compute_changes(self, original, enhanced):
        changes = []
        matcher = SequenceMatcher(None, original, enhanced)  # char-level, not token-level
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag in ("replace", "insert", "delete"):
                original_phrase = original[i1:i2]
                new_phrase = enhanced[j1:j2]
                is_whitespace = original_phrase.isspace() or new_phrase.isspace()
                changes.append({
                    "original_phrase": original_phrase,
                    "new_phrase": new_phrase,
                    "char_start": i1,
                    "char_end": i2,
                    "token_start": None,  # not token-based anymore
                    "token_end": None,
                    "explanation": f"{tag} change",
                    "error_type": "whitespace" if is_whitespace else "",
                    "tip": "Avoid extra spaces between words." if is_whitespace else ""
                })
        return changes

    def __call__(self, inputs):
        # Main entry point for the Hugging Face Endpoint. Accept a bare string,
        # a list of strings, or the common {"inputs": ..., "parameters": {...}}
        # request format.
        if isinstance(inputs, str):
            sentences = [inputs]
            parameters = {}
        elif isinstance(inputs, list):
            sentences = inputs
            parameters = {}
        elif isinstance(inputs, dict):
            sentences = inputs.get("inputs", [])
            # If "inputs" holds a single string, wrap it in a list
            if isinstance(sentences, str):
                sentences = [sentences]
            parameters = inputs.get("parameters", {})
        else:
            return {
                "success": False,
                "error": "Invalid input format. Expected a string, a list of strings, "
                         "or a dictionary with 'inputs' and 'parameters' keys."
            }

        # Handle optional generation parameters
        num_return_sequences = parameters.get("num_return_sequences", 1)
        temperature = parameters.get("temperature", 1.0)

        if not sentences:
            return {
                "success": False,
                "error": "No sentences provided."
            }

        try:
            paraphrased = self.paraphrase_batch(
                sentences, num_return_sequences, temperature
            )
            results = []
            if num_return_sequences > 1:
                # Multiple candidates per sentence: emit one result per candidate
                for i, orig in enumerate(sentences):
                    for cand in paraphrased[i]:
                        results.append({
                            "original_sentence": orig,
                            "enhanced_sentence": cand,
                            "changes": self.compute_changes(orig, cand)
                        })
            else:
                # Single candidate per sentence
                for orig, cand in zip(sentences, paraphrased):
                    results.append({
                        "original_sentence": orig,
                        "enhanced_sentence": cand,
                        "changes": self.compute_changes(orig, cand)
                    })
            return {
                "success": True,
                "results": results,
                "sentences_count": len(sentences),
                "processed_count": len(results),
                "skipped_count": 0,
                "error_count": 0
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "sentences_count": len(sentences),
                "processed_count": 0,
                "skipped_count": 0,
                "error_count": 1
            }
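

# A minimal local smoke test (illustrative sketch, not part of the handler).
# It assumes the "grammarly/coedit-large" weights can be fetched from the
# Hugging Face Hub; the example sentence and payload values are made up.
# Inference Endpoints invoke __call__ with the parsed JSON request body, so
# the same payload shape works when calling the handler directly.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": ["She go to school every days."],
        "parameters": {"num_return_sequences": 1}
    }
    response = handler(payload)
    print("success:", response["success"])
    for result in response.get("results", []):
        print(result["original_sentence"], "->", result["enhanced_sentence"])
        for change in result["changes"]:
            print("  ", change["explanation"], repr(change["original_phrase"]),
                  "->", repr(change["new_phrase"]))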