import os
import time
from datetime import datetime
import logging
from pathlib import Path
import requests
import json

import numpy as np
import pandas as pd
import spacy
import litellm
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoModelForTokenClassification,
    AutoConfig,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from peft import PeftModel
import torch
import cohere
from openai import OpenAI
from together import Together
import anthropic
import replicate
# import google.generativeai as genai
import vertexai
from vertexai.generative_models import GenerativeModel, Part, SafetySetting, FinishReason
from mistralai import Mistral
from qwen_vl_utils import process_vision_info

import src.backend.util as util
import src.envs as envs

litellm.set_verbose = True

# Set up basic configuration for logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Load spaCy model for word tokenization
nlp = spacy.load("en_core_web_sm")

os.environ["HUGGINGFACE_API_KEY"] = envs.TOKEN

class ModelLoadingException(Exception):
    """Exception raised for errors in loading a model.

    Attributes:
        model_id (str): The model identifier.
        revision (str): The model revision.
    """

    def __init__(self, model_id, revision, messages="Error initializing model"):
        self.model_id = model_id
        self.revision = revision
        super().__init__(f"{messages} id={model_id} revision={revision}")

class SummaryGenerator:
    """A class to generate summaries using a causal language model.

    Attributes:
        model (str): huggingface/{model_id}
        api_base (str): https://api-inference.huggingface.co/models/{model_id}
        summaries_df (DataFrame): DataFrame to store generated summaries.
        revision (str): Model revision.
        avg_length (float): Average length of summaries.
        answer_rate (float): Rate of non-empty summaries.
    """

    def __init__(self, model_id, revision, device):
        """
        Initializes the SummaryGenerator with a model.

        Args:
            model_id (str): Identifier for the model.
            revision (str): Revision of the model.
        """
        self.model_id = model_id
        self.model = f"huggingface/{model_id}"
        self.api_base = f"https://api-inference.huggingface.co/models/{model_id}"
        self.summaries_df = pd.DataFrame()
        self.revision = revision
        self.device = device
        self.avg_length = None
        self.answer_rate = None
        self.exceptions = None
        self.local_model = None
        self.local_pipeline = None
    def generate_summaries(self, df, save_path=None):
        """Generate summaries for a given DataFrame of source docs.

        Args:
            df (DataFrame): DataFrame containing source docs.
            save_path (str, optional): CSV path used to cache generated summaries.

        Returns:
            summaries_df (DataFrame): Generated summaries by the model.
        """
        exceptions = []
        if (save_path is not None) and os.path.exists(save_path):
            self.summaries_df = pd.read_csv(save_path)
            print(f'Loaded generated summaries from {save_path}')
        else:
            source, summary, dataset = [], [], []
            print(f"Total: {df.shape[0]}")
            for index, row in tqdm(df.iterrows(), total=df.shape[0]):
                _source = row['text']
                _dataset = row['dataset']

                system_prompt = envs.SYSTEM_PROMPT
                user_prompt = f"{envs.USER_PROMPT}\nPassage:\n{_source}"

                _summary = None
                while not _summary:
                    try:
                        _summary = self.generate_summary(system_prompt, user_prompt)
                        # print(f"Finish index {index}")
                        break
                    except Exception as e:
                        if 'Rate limit reached' in str(e):
                            wait_time = 300
                            current_time = datetime.now().strftime('%H:%M:%S')
                            print(f"Rate limit hit at {current_time}. Waiting for 5 minutes before retrying...")
                            time.sleep(wait_time)
                        elif 'is currently loading' in str(e):
                            wait_time = 200
                            print(f"Model is loading, wait for {wait_time}")
                            time.sleep(wait_time)
                        elif '429' in str(e):  # for gemini models
                            wait_time = 60
                            print(f"Quota has been reached, wait for {wait_time}")
                            time.sleep(wait_time)
                        else:
                            print(f"Error at index {index}: {e}")
                            _summary = ""
                            exceptions.append(index)
                            break

                summary.append(_summary)
                source.append(_source)
                dataset.append(_dataset)

                # Sleep to prevent hitting rate limits too frequently
                time.sleep(1)

            self.summaries_df = pd.DataFrame(list(zip(source, summary, dataset)),
                                             columns=["source", "summary", "dataset"])

            if save_path is not None:
                print(f'Save summaries to {save_path}')
                fpath = Path(save_path)
                fpath.parent.mkdir(parents=True, exist_ok=True)
                self.summaries_df.to_csv(fpath)

        self.exceptions = exceptions
        self._compute_avg_length()
        self._compute_answer_rate()

        return self.summaries_df
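
    # Illustrative input for generate_summaries (a sketch, not part of the original
    # module). The column names 'text' and 'dataset' come from the loop above; the
    # values and the save path are assumptions for illustration only.
    #
    #   df = pd.DataFrame({
    #       "text": ["The quick brown fox jumps over the lazy dog."],
    #       "dataset": ["example"],
    #   })
    #   summaries_df = generator.generate_summaries(df, save_path="generation_results/example.csv")
    #   # -> DataFrame with columns ["source", "summary", "dataset"]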
    def generate_summary(self, system_prompt: str, user_prompt: str):
        # Decide which backend serves this model: Together AI, Replicate,
        # a local Transformers pipeline, or one of the dedicated APIs below.
        using_together_api = False
        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3-', 'qwen2-72b-instruct', 'zero-one-ai', 'llama-3.2-']  # , 'mistralai'
        using_replicate_api = False
        replicate_api_models = ['snowflake', 'llama-3.1-405b']
        using_pipeline = False
        pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b', 'phi-3.5', 'mistral-nemo', 'llama-3.3', 'phi-4']

        for replicate_api_model in replicate_api_models:
            if replicate_api_model in self.model_id.lower():
                using_replicate_api = True
                break

        if not using_replicate_api:
            for together_ai_api_model in together_ai_api_models:
                if together_ai_api_model in self.model_id.lower():
                    using_together_api = True
                    break

        if not using_replicate_api and not using_together_api:
            for pipeline_model in pipeline_models:
                if pipeline_model in self.model_id.lower():
                    using_pipeline = True
                    break
        # Using Together AI API
        if using_together_api:
            print('using together api')
            client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
            if 'llama-3.2-90b-vision' in self.model_id.lower() or 'llama-3.2-11b-vision' in self.model_id.lower():
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": [{"type": "text", "text": user_prompt}]}
                ]
            else:
                messages = [{"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt}]
            response = client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                max_tokens=250,
                temperature=0,
            )
            # print(response)
            result = response.choices[0].message.content
            print(result)
            return result
        # Using OpenAI API
        elif 'openai' in self.model_id.lower():
            client = OpenAI()
            response = client.chat.completions.create(
                model=self.model_id.replace('openai/', ''),
                messages=[{"role": "system", "content": system_prompt},
                          {"role": "user", "content": user_prompt}] if 'gpt' in self.model_id
                         else [{"role": "user", "content": system_prompt + '\n' + user_prompt}],
                temperature=0.0 if 'gpt' in self.model_id.lower() else 1.0,  # fixed at 1 for o1 models
                # max_completion_tokens=250 if 'gpt' in self.model_id.lower() else None,  # not compatible with o1 series models
            )
            # print(response)
            result = response.choices[0].message.content
            print(result)
            return result
        # Using Grok API (xAI, OpenAI-compatible endpoint)
        elif 'grok' in self.model_id.lower():
            XAI_API_KEY = os.getenv("XAI_API_KEY")
            client = OpenAI(
                api_key=XAI_API_KEY,
                base_url="https://api.x.ai/v1",
            )
            completion = client.chat.completions.create(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.0
            )
            result = completion.choices[0].message.content
            print(result)
            return result
        # Using Vertex AI API for Gemini models
        elif 'gemini' in self.model_id.lower():
            vertexai.init(project=os.getenv("GOOGLE_PROJECT_ID"), location="us-central1")
            model = GenerativeModel(
                self.model_id.lower().split('google/')[-1],
                system_instruction=[system_prompt]
            )
            generation_config = {
                "temperature": 0,
                "max_output_tokens": 500,
            }
            safety_settings = [
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
                SafetySetting(
                    category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
                    threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                ),
            ]
            response = model.generate_content(
                user_prompt,
                safety_settings=safety_settings,
                generation_config=generation_config
            )
            result = response.text
            print(result)
            return result
        # Using Replicate API
        elif using_replicate_api:
            print("using replicate")
            if 'snowflake' in self.model_id.lower():
                input = {
                    "prompt": user_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250,
                    "stop_sequences": "<|im_end|>",
                    "prompt_template": f"<|im_start|>system\n{system_prompt}<|im_end|>\n" + "<|im_start|>user\n{prompt}<|im_end|>\n\n<|im_start|>assistant\n",
                }
            else:
                input = {
                    "prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "temperature": 0,
                    "max_new_tokens": 250
                }
            response = replicate.run(
                self.model_id,
                input=input
            )
            if isinstance(response, list):
                response = ''.join(response)
            print(response)
            return response
        # Using Anthropic API for Claude models
        elif 'claude' in self.model_id.lower():
            print('using Anthropic API')
            client = anthropic.Anthropic()
            message = client.messages.create(
                model=self.model_id.split('/')[-1],
                max_tokens=1024,
                temperature=0,
                system=system_prompt,
                messages=[
                    {
                        "role": "user",
                        "content": user_prompt
                    }
                ]
            )
            result = message.content[0].text
            print(result)
            return result
        # Using Cohere API
        elif 'command-r' in self.model_id.lower() or 'aya-expanse' in self.model_id.lower():
            co = cohere.ClientV2(os.getenv('COHERE_API_TOKEN'))
            response = co.chat(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0,
            )
            result = response.message.content[0].text
            print(result)
            return result
        # Using Mistral AI API
        elif 'mistral-large' in self.model_id.lower():
            api_key = os.environ["MISTRAL_API_KEY"]
            client = Mistral(api_key=api_key)
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            # No streaming
            chat_response = client.chat.complete(
                model=self.model_id,
                messages=messages,
            )
            result = chat_response.choices[0].message.content
            print(result)
            return result
        # Using DeepSeek API
        elif 'deepseek' in self.model_id.lower():
            client = OpenAI(api_key=os.getenv("DeepSeek_API_KEY"), base_url="https://api.deepseek.com")
            response = client.chat.completions.create(
                model=self.model_id.split('/')[-1],
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=250,
                temperature=0,
                stream=False
            )
            result = response.choices[0].message.content
            print(result)
            return result
        # Using a HF pipeline or local checkpoint: load it on first use
        elif self.local_model is None and self.local_pipeline is None:
            if using_pipeline:
                self.local_pipeline = pipeline(
                    "text-generation",
                    model=self.model_id,
                    tokenizer=AutoTokenizer.from_pretrained(self.model_id),
                    torch_dtype=torch.bfloat16 if 'llama-3.2' in self.model_id.lower() or 'llama-3.3' in self.model_id.lower() else "auto",
                    device_map="auto",
                    trust_remote_code=True
                )
            else:
                if 'ragamuffin' in self.model_id.lower():
                    self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('/home/miaoran', self.model_id))
                else:
                    self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
                print("Tokenizer loaded")

                if 'jamba' in self.model_id.lower():
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
                                                                            torch_dtype=torch.bfloat16,
                                                                            attn_implementation="flash_attention_2",
                                                                            device_map="auto",
                                                                            use_mamba_kernels=False)
                elif 'qwen2-vl' in self.model_id.lower():
                    self.local_model = Qwen2VLForConditionalGeneration.from_pretrained(
                        self.model_id, torch_dtype="auto", device_map="auto"
                    )
                    self.processor = AutoProcessor.from_pretrained(self.model_id)
                # elif 'ragamuffin' in self.model_id.lower():
                #     print('Using ragamuffin')
                #     self.local_model = AutoModelForCausalLM.from_pretrained(os.path.join('/home/miaoran', self.model_id),
                #                                                             torch_dtype=torch.bfloat16,  # forcing bfloat16 for now
                #                                                             attn_implementation="flash_attention_2")
                elif 'olmo' in self.model_id.lower():
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id)  # torch_dtype="auto"
                elif 'qwq-' in self.model_id.lower():
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, torch_dtype="auto", device_map="auto")
                else:
                    self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto")  # torch_dtype="auto"
                # print(self.local_model.device)
                print("Local model loaded")
        # Using local model/pipeline
        if self.local_pipeline:
            print('Using Transformers pipeline')
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ]
            outputs = self.local_pipeline(
                messages,
                max_new_tokens=256,
                # return_full_text=False,
                do_sample=False
            )
            result = outputs[0]["generated_text"][-1]['content']
            print(result)
            return result
        elif self.local_model:  # cannot call an API; use the local model
            print('Using local model')
            # Build the prompt in the format each model family expects
            if 'gemma' in self.model_id.lower() or 'mistral-7b' in self.model_id.lower():
                # gemma-1.1 and mistral-7b do not accept a system role
                messages = [
                    {"role": "user", "content": system_prompt + '\n' + user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            elif 'phi-2' in self.model_id.lower():
                prompt = system_prompt + '\n' + user_prompt
            elif 'intel' in self.model_id.lower():
                prompt = f"### System:\n{system_prompt}\n### User:\n{user_prompt}\n### Assistant:\n"
            elif 'qwen2-vl' in self.model_id.lower():
                messages = [
                    {
                        "role": "system",
                        "content": [
                            {"type": "text", "text": system_prompt}
                        ]
                    },
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                        ],
                    }
                ]
                # The original branch left `prompt` unset; the standard Qwen2-VL
                # processor chat template is used here (an assumed fix).
                prompt = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            else:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
                prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            # Tokenize inputs
            if 'qwen2-vl' in self.model_id.lower():
                # Standard Qwen2-VL input preparation via the processor (an assumed
                # fix; the original code did not show a qwen2-vl tokenization path).
                image_inputs, video_inputs = process_vision_info(messages)
                input_ids = self.processor(text=[prompt], images=image_inputs, videos=video_inputs,
                                           padding=True, return_tensors="pt").to(self.device)
            elif 'olmo' in self.model_id.lower():
                input_ids = self.tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)  # .to(self.device)
            elif 'qwq' in self.model_id.lower():
                input_ids = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            else:
                input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate outputs
            if 'granite' in self.model_id.lower():
                self.local_model.eval()
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250)
            elif 'olmo' in self.model_id.lower():
                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01)  # top_k=50, top_p=0.95
            elif 'qwq' in self.model_id.lower():
                outputs = self.local_model.generate(**input_ids, max_new_tokens=512, do_sample=True, temperature=0.01)
            else:
                with torch.no_grad():
                    outputs = self.local_model.generate(**input_ids, do_sample=True, max_new_tokens=250, temperature=0.01)  # pad_token_id=self.tokenizer.eos_token_id

            # Strip the prompt tokens from the generated sequences where needed
            if 'glm' in self.model_id.lower() or 'ragamuffin' in self.model_id.lower() or 'granite' in self.model_id.lower():
                outputs = outputs[:, input_ids['input_ids'].shape[1]:]
            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower() or 'qwq-' in self.model_id.lower():
                outputs = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids.input_ids, outputs)
                ]
            # Decode outputs
            if 'qwen2-vl' in self.model_id.lower():
                result = self.processor.batch_decode(
                    outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )[0]
            elif 'olmo' in self.model_id.lower() or 'qwq' in self.model_id.lower():
                result = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            else:
                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Strip the echoed prompt text from the decoded output where needed
            if 'gemma-2' in self.model_id.lower():
                result = result.split(user_prompt + '\nmodel')[-1].strip()
            elif 'intel' in self.model_id.lower():
                result = result.split("### Assistant:\n")[-1]
            elif 'jamba' in self.model_id.lower():
                result = result.split(messages[-1]['content'])[1].strip()
            elif 'qwen2-vl' in self.model_id.lower() or 'qwen2.5' in self.model_id.lower():
                pass
            elif 'olmo' in self.model_id.lower():
                result = result.split("<|assistant|>\n")[-1]
            else:
                result = result.replace(prompt.strip(), '')

            print(result)
            return result
    def _compute_avg_length(self):
        """
        Compute the average length of non-empty summaries using spaCy.
        """
        total_word_count = 0
        total_count = 0

        for summary in self.summaries_df['summary']:
            if util.is_summary_valid(summary):
                doc = nlp(summary)
                words = [token.text for token in doc if token.is_alpha]
                total_word_count += len(words)
                total_count += 1

        self.avg_length = 0 if total_count == 0 else total_word_count / total_count

    def _compute_answer_rate(self):
        """
        Compute the rate of non-empty summaries.
        """
        valid_count = sum(1 for summary in self.summaries_df['summary']
                          if util.is_summary_valid(summary))
        total_count = len(self.summaries_df)
        self.answer_rate = 0 if total_count == 0 else valid_count / total_count
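
# Illustrative usage of SummaryGenerator (a sketch, not part of the original module).
# The model id, revision, device, and source_df below are assumptions for illustration only.
#
#   generator = SummaryGenerator("openai/gpt-4o", revision="main", device="cuda")
#   summaries_df = generator.generate_summaries(source_df)
#   print(generator.avg_length)   # average word count of valid summaries
#   print(generator.answer_rate)  # fraction of non-empty summaries
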
class EvaluationModel:
    """A class to evaluate generated summaries with the HHEM model.

    Attributes:
        model: The evaluation model (a token-classification head on a flan-t5-large backbone).
        scores (list): List of evaluation scores.
        factual_consistency_rate (float): Rate of factually consistent summaries.
        hallucination_rate (float): Rate of hallucination in summaries.
    """

    def __init__(self, model_path, device):
        """
        Initializes the EvaluationModel.

        Args:
            model_path (str): Path or Hugging Face ID of the evaluation model checkpoint.
        """
        config = AutoConfig.from_pretrained('google/flan-t5-large')
        self.model = AutoModelForTokenClassification.from_pretrained(model_path, config=config)
        self.device = device
        self.model.to(self.device)
        self.scores = []
        self.factual_consistency_rate = None
        self.hallucination_rate = None
    def predict(self, text_pairs):
        """Run HHEM and return factual-consistency scores.

        All HHEM 2.1 settings, e.g., the prompt template, are hardcoded in this function.

        Args:
            text_pairs: list of tuples, each containing two strings (premise, hypothesis).

        Returns:
            list of float: one consistency score per (premise, hypothesis) pair.
        """
        prompt = "<pad> Determine if the hypothesis is true given the premise?\n\nPremise: {text1}\n\nHypothesis: {text2}"
        tokenizer = AutoTokenizer.from_pretrained('t5-base')
        inputs = tokenizer(
            [prompt.format(text1=pair[0], text2=pair[1]) for pair in text_pairs],
            return_tensors='pt', padding='longest').to(self.device)
        self.model.eval()
        with torch.no_grad():
            output = self.model(**inputs)
        logits = output.logits
        logits = logits[:, 0, :]  # get the logits of the first token
        logits = torch.softmax(logits, dim=-1)
        scores = [round(x, 5) for x in logits[:, 1].tolist()]  # list of floats
        return scores
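
    # Example call to predict (a sketch, not part of the original module; the
    # premise/hypothesis pair is illustrative only):
    #
    #   scores = evaluation_model.predict([
    #       ("The capital of France is Paris.", "Paris is the capital of France."),
    #   ])
    #   # scores is a list of floats in [0, 1]; higher means the hypothesis is
    #   # judged more consistent with the premise.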
    def evaluate_hallucination(self, summaries_df):
        """
        Evaluate the hallucination rate in summaries. Updates the 'scores' attribute
        of the instance with the computed scores.

        Args:
            summaries_df (DataFrame): DataFrame containing source docs and summaries.

        Returns:
            tuple: (hem_scores, eval_results), where hem_scores is the list of HHEM
            scores and eval_results is a dict of sources, summaries, and scores.
            Also updates the 'scores' attribute of the instance.
        """
        hem_scores = []
        sources = []
        summaries = []
        source_summary_pairs = util.create_pairs(summaries_df)

        for doc, summary in source_summary_pairs:
            if util.is_summary_valid(summary):
                try:
                    summary = util.normalize_summary(summary)
                    score = self.predict([(doc, summary)])[0]
                    hem_scores.append(score)
                    sources.append(doc)
                    summaries.append(summary)
                    if score < 0.5:
                        print(score)
                        print(doc)
                        print('-' * 20)
                        print(summary)
                        print('=' * 50)
                except Exception as e:
                    logging.error(f"Error while running HEM: {e}")
                    raise

        self.scores = hem_scores
        eval_results = {'source': sources, 'summary': summaries, 'HEM scores': hem_scores}
        return hem_scores, eval_results
    def compute_factual_consistency_rate(self, threshold=0.5):
        """
        Compute the factual consistency rate of the evaluated summaries based on
        the previously calculated scores. This method relies on the 'scores'
        attribute being populated, typically via the 'evaluate_hallucination' method.

        Args:
            threshold (float): Minimum score for a summary to count as factually consistent.

        Returns:
            float: Factual Consistency Rate. Also updates the 'factual_consistency_rate'
            and 'hallucination_rate' attributes of the instance.

        Raises:
            ValueError: If scores have not been calculated prior to calling this method.
        """
        if not self.scores:
            error_msg = "Scores not calculated. Call evaluate_hallucination() first."
            logging.error(error_msg)
            raise ValueError(error_msg)

        # Use the threshold (default 0.5) to compute factual_consistency_rate
        num_above_threshold = sum(score >= threshold for score in self.scores)
        num_total = len(self.scores)

        if not num_total:
            raise ValueError("No scores available to compute factual consistency rate.")

        self.factual_consistency_rate = (num_above_threshold / num_total) * 100
        self.hallucination_rate = 100 - self.factual_consistency_rate

        return self.factual_consistency_rate
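
# Illustrative end-to-end evaluation (a sketch, not part of the original module).
# The checkpoint path is a placeholder and summaries_df is assumed to be the output
# of SummaryGenerator.generate_summaries().
#
#   evaluator = EvaluationModel("path/to/hhem-2.1-checkpoint", device="cuda")  # placeholder path
#   hem_scores, eval_results = evaluator.evaluate_hallucination(summaries_df)
#   factual_consistency = evaluator.compute_factual_consistency_rate(threshold=0.5)
#   print(f"Factual consistency rate: {factual_consistency:.1f}%")
#   print(f"Hallucination rate: {evaluator.hallucination_rate:.1f}%")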