| """ | |
| adapted to support pegasus-xsum / local files | |
| """ | |
| # Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Perplexity Metric.""" | |
| import datasets | |
| import numpy as np | |
| import torch | |
| from torch.nn import CrossEntropyLoss | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import getpass | |
| import evaluate | |
| from evaluate import logging | |
| import pdb | |

WINDOW_SIZE = 3


def prepare_coh_sents(predictions):
    """Split each prediction into overlapping blocks of WINDOW_SIZE sentences.

    Returns the flattened list of blocks and, per prediction, the number of
    blocks it produced (1 if the text has at most WINDOW_SIZE sentences).
    """
    blocks = []
    lens = []
    for pred in predictions:
        sents = pred.split("\n")
        if len(sents) <= WINDOW_SIZE:
            blocks.append(pred)
            lens.append(1)
        else:
            _block = []
            for i in range(0, len(sents) - WINDOW_SIZE + 1):
                _block.append("\n".join(sents[i:i + WINDOW_SIZE]))
            lens.append(len(_block))
            blocks.extend(_block)
    return blocks, lens
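
# Illustrative example (inputs are hypothetical): with WINDOW_SIZE = 3,
#   prepare_coh_sents(["s1\ns2\ns3\ns4"]) -> (["s1\ns2\ns3", "s2\ns3\ns4"], [2])
#   prepare_coh_sents(["s1\ns2"])         -> (["s1\ns2"], [1])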
| _CITATION = """\ | |
| """ | |
| _DESCRIPTION = """ | |
| Perplexity (PPL) is one of the most common metrics for evaluating language models. | |
| It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`. | |
| For more information, see https://huggingface.co/docs/transformers/perplexity | |
| """ | |

_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity.
        NOTE: Perplexity can only be calculated for causal language models.
        This includes models such as gpt2, causal variations of bert,
        causal versions of t5, and more (the full list can be found
        in the AutoModelForCausalLM documentation here:
        https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
    predictions (list of str): input text, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
    device (str): device to run on; defaults to 'cuda' when available.
Returns:
    perplexity: dictionary containing the perplexity scores for the texts
        in the input list, as well as the mean perplexity. If one of the input texts is
        longer than the max input length of the model, then it is truncated to the
        max length for the perplexity computation.
Examples:
    Example 1:
        >>> perplexity = evaluate.load("perplexity", module_type="metric")
        >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              add_start_token=False,
        ...                              predictions=input_texts) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 0))
        647.0
        >>> print(round(results["perplexities"][0], 0))
        32.0

    Example 2:
        >>> from datasets import load_dataset
        >>> perplexity = evaluate.load("perplexity", module_type="metric")
        >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
        >>> input_texts = [s for s in input_texts if s != '']
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              predictions=input_texts)
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
        576.76
        >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
        889.28
"""


class LocalCohPPL(evaluate.Measurement):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="measurement",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/spaces/ronaldahmed/local_coh_ppl"],
        )
    ## PREDICTIONS: [str] sentences joined by "\n"
    def _compute(self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
        # user-specific cache directories for shared machines
        MODEL_CACHE_DIR = "/home/rcardena/.cache/huggingface/"
        if getpass.getuser() == "s1987051":
            MODEL_CACHE_DIR = "/disk/ocean/rcardenas/tools/huggingface/"
        elif getpass.getuser() == "rcardena":
            MODEL_CACHE_DIR = "/gfs/team/nlp/users/rcardena/tools/huggingface/"

        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be either cpu, gpu, or cuda."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=MODEL_CACHE_DIR)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=MODEL_CACHE_DIR,
            # fast tokenizer is disabled for cnn_dailymail-based checkpoints
            use_fast="cnn_dailymail" not in model_id,
        )
        # if batch_size > 1 (which generally leads to padding being required), and
        # if there is not an already assigned pad_token, assign an existing
        # special token to also be the padding token
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
            # check that the model already has at least one special token defined
            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

        # cap scibert inputs at 512 tokens
        model.config.max_length = 512 if "scibert" in model_id else model.config.max_length

        if add_start_token:
            # leave room for <BOS> token to be added:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
            max_tokenized_len = model.config.max_length - 1
        else:
            max_tokenized_len = model.config.max_length

        loss_fct = CrossEntropyLoss(reduction="none")
        blocks, blens = prepare_coh_sents(predictions)
        all_norm_ppl = []
        for start_index in logging.tqdm(range(0, len(blocks), batch_size)):
            end_index = min(start_index + batch_size, len(blocks))
            batch_sents = blocks[start_index:end_index]

            encodings = tokenizer(
                batch_sents,
                add_special_tokens=False,
                padding=True,
                truncation=True,
                max_length=max_tokenized_len,
                return_tensors="pt",
                return_attention_mask=True,
            ).to(device)

            encoded_texts = encodings["input_ids"]
            attn_masks = encodings["attention_mask"]

            # check that each input is long enough:
            if add_start_token:
                assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
            else:
                assert torch.all(
                    torch.ge(attn_masks.sum(1), 2)
                ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

            if add_start_token:
                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_texts.size(dim=0)).to(device)
                encoded_texts = torch.cat([bos_tokens_tensor, encoded_texts], dim=1)
                attn_masks = torch.cat(
                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_masks], dim=1
                )

            # tokenize by sentence: record the token offset where each sentence
            # starts, so per-sentence losses can be sliced out below
            sent_tok_lens = []
            for pred in batch_sents:
                ss = pred.split("\n")
                sslens = [len(tokenizer(y, add_special_tokens=False, padding=False).input_ids) for y in ss]
                offset = int(add_start_token)
                sspos = [offset]
                for sslen in sslens:
                    # NOTE: the hardcoded 511 assumes model.config.max_length == 512
                    offset = min(offset + sslen, 511 + int(add_start_token))
                    if offset == sspos[-1]:
                        break  # reached length limit
                    sspos.append(offset)
                sent_tok_lens.append(sspos)
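            # Illustrative (hypothetical) example: sentence token lengths [5, 7, 4]
            # with add_start_token=True give sspos == [1, 6, 13, 17].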

            labels = encoded_texts

            with torch.no_grad():
                out_logits = model(encoded_texts, attention_mask=attn_masks).logits

            # shift so that position t predicts token t + 1, then mask out padding
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_masks[..., 1:].contiguous()

            loss_out = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch

            # perplexity of the whole window: exp of the mean token-level NLL
            perplexity_all = torch.exp(
                loss_out.sum(1) / shift_attention_mask_batch.sum(1)
            ).detach().cpu().numpy().tolist()
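
            # Normalize each window's perplexity by the sum of its per-sentence
            # perplexities; windows where only one sentence was scored get 0.0.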
            norm_ppl = []
            for b, stl in enumerate(sent_tok_lens):
                indv = []
                for i in range(1, len(stl)):
                    ppl = torch.exp(
                        loss_out[b, stl[i - 1]:stl[i]].sum() / shift_attention_mask_batch[b, stl[i - 1]:stl[i]].sum()
                    ).detach().cpu().item()
                    indv.append(ppl)
                nppl = perplexity_all[b] / sum(indv) if len(indv) > 1 else 0.0
                norm_ppl.append(nppl)

            all_norm_ppl.extend(norm_ppl)
            if any(np.isnan(norm_ppl)):
                print("[compute ppl] nan ...")
                pdb.set_trace()
                print(">>")

        # average the normalized window scores back into one score per prediction
        avg_ppl = []
        offset = 0
        for _len in blens:
            avg_ppl.append(float(np.mean(all_norm_ppl[offset:offset + _len])))
            offset += _len
        if any(np.isnan(avg_ppl)):
            print("[compute ppl] nan ...")
            pdb.set_trace()
            print(">>")
        return {"local_coh_ppl": avg_ppl}