from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import torch
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the English-to-Russian translation model and prepare it with accelerate
accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
device = accelerator.device
model = accelerator.prepare(model)
@dataclass
class TranslationResult:
    """Container for a translation plus token-level scores and cross-attention."""
    input_text: str
    n_input: int
    input_tokens: List[str]
    n_output: int
    output_text: str
    output_tokens: List[str]
    output_scores: List[List[Tuple[str, float]]]  # top-k (token, probability) pairs per output token
    cross_attention: np.ndarray  # shape (n_output, n_input)
def translator_fn(input_text: str, k: int = 10) -> TranslationResult:
    # Preprocess input and mark which input tokens are special (e.g. </s>, <pad>)
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
    input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens]).to(device)

    # Generate output with per-step scores and attention weights
    outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
    output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    output_tokens = tokenizer.batch_decode(outputs.sequences[0])
    output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens]).to(device)

    # outputs.cross_attentions has one entry per generated token; each entry is a tuple over
    # decoder layers of tensors shaped (batch, heads, query_len, input_len). Average over the
    # query, head, batch and layer dimensions to get a (generated_len, input_len) matrix.
    cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
    attention_matrix = cross_attention.mean(dim=4).mean(dim=3).mean(dim=2).mean(dim=1).detach().cpu().numpy()
    # Top-k alternatives for each generated token (outputs.scores[i] holds the scores
    # used to pick the token at sequence position i + 1)
    top_scores = []
    len_input = len(input_tokens)
    len_output = len(output_tokens)
    for i in range(len_output - 1):
        if output_special_mask[i + 1] == 1:
            # Skip steps that produced special tokens (e.g. </s>, <pad>)
            continue
        top_elements, top_indices = outputs.scores[i].mean(dim=0).topk(k)
        top_elements = top_elements.exp()
        top_elements /= top_elements.sum()  # renormalize the top-k scores into probabilities
        top_indices = tokenizer.batch_decode(top_indices)
        # Filter out special tokens
        top_pairs = [(m, t.item()) for t, m in zip(top_elements, top_indices) if m not in tokenizer.all_special_tokens]
        top_scores.append(top_pairs)
    # Drop special tokens from the token lists and the attention matrix. The rows of
    # attention_matrix correspond to the generated tokens, i.e. output positions
    # 1..len_output - 1, so the leading decoder-start token is excluded from the row mask.
    clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
    clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
    output_special_np = output_special_mask.detach().cpu().numpy()
    input_special_np = input_special_mask.detach().cpu().numpy()
    clean_attention_matrix = np.delete(attention_matrix, np.where(output_special_np[1:] == 1)[0], axis=0)
    clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_np == 1)[0], axis=1)

    n_input = len(clean_input_tokens)
    n_output = len(clean_output_tokens)
    assert clean_attention_matrix.shape == (n_output, n_input)
    assert len(top_scores) == n_output
    return TranslationResult(
        input_text=input_text,
        n_input=n_input,
        input_tokens=clean_input_tokens,
        output_text=output_text,
        n_output=n_output,
        output_tokens=clean_output_tokens,
        output_scores=top_scores,
        cross_attention=clean_attention_matrix,
    )
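

# Minimal usage sketch (an assumption, not part of the original Space): it only calls the
# translator_fn defined above and reads the TranslationResult fields; the example sentence
# and the printed summary are illustrative.
if __name__ == "__main__":
    result = translator_fn("The cat sat on the mat.", k=5)
    print(result.output_text)            # translated sentence
    print(result.input_tokens)           # source tokens with specials removed
    print(result.output_tokens)          # generated tokens with specials removed
    print(result.output_scores[0][:3])   # top alternatives for the first generated token
    print(result.cross_attention.shape)  # (n_output, n_input)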