| | import json |
| | import time |
| | import os |
| | import argparse |
| | from datasets import load_dataset |
| | from openai import OpenAI |
| | from tqdm import tqdm |
| | from utils.metrics import qa_f1_score, qa_em_score |
| |
|
| | |
# Module-level OpenAI client shared by the API helper functions in this script.
# The API key and base URL are read from the environment; an unset variable is
# passed through as None (presumably letting the SDK apply its own defaults —
# confirm against the openai-python documentation).
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)
| |
|
def get_openai_response(prompt, model="gpt-4o", retries=3, delay=2, max_tokens=100):
    """Send ``prompt`` as a single user message and return the stripped reply.

    Any exception raised while calling the API or reading the reply counts as a
    failed attempt. Failed attempts are retried after sleeping ``delay`` seconds,
    up to ``retries`` attempts in total.

    Args:
        prompt: Text sent as the only user message.
        model: Chat model name passed to ``client.chat.completions.create``.
        retries: Maximum number of attempts. With ``retries <= 0`` no request
            is made and the failure sentinel is returned.
        delay: Seconds to sleep between consecutive failed attempts.
        max_tokens: Completion token cap for the reply (default 100).

    Returns:
        The reply text with surrounding whitespace stripped, or the sentinel
        string ``"Failed to get response"`` when no attempt succeeds. Callers
        compare against this exact string, so it must not change.
    """
    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}],
                max_tokens=max_tokens,
            )
            # If message.content is None, .strip() raises AttributeError and
            # the attempt is handled like any other failure below.
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
    # Reached when every attempt failed, or when retries <= 0 and the loop never
    # ran; either way return the sentinel rather than an implicit None.
    print("Max retries reached. Skipping this request.")
    return "Failed to get response"
| |
|
def rephrase_question_api(question, model_name, rephrase_type="opposite"):
    """Ask ``model_name`` to reword ``question``.

    ``rephrase_type`` selects the instruction: "opposite" requests a question
    with the exact opposite meaning, and "similar" requests a synonymous
    rewording. Any other value raises ValueError before any API call is made.

    Returns whatever ``get_openai_response`` returns for the built prompt.
    """
    if rephrase_type not in ("opposite", "similar"):
        raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")

    if rephrase_type == "similar":
        header = "Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:"
        footer = "Return only the rephrased question, without any explanations or other content."
    else:
        header = "Please rephrase the following question to have the exact opposite meaning."
        footer = "Return only the rephrased question with the opposite meaning, without any explanations or other content."

    # Layout: instruction line, the question, a blank line, then the output rule.
    prompt = "\n".join([header, f"Question: {question}", "", footer])
    return get_openai_response(prompt, model=model_name)
| |
|
def answer_question_with_context_api(question, context, model_name, max_tokens_for_answer=30,
                                     retries=3, delay=2):
    """Answer ``question`` from ``context`` using ``model_name``.

    The model is told to reply with only the answer, or "I don't know" when the
    context does not contain it. Any exception from the API call (or from
    reading the reply) counts as a failed attempt. Failed attempts are retried
    after ``delay`` seconds, up to ``retries`` attempts in total. This mirrors the
    retry policy of ``get_openai_response``, so a transient error does not
    silently drop a sample from the evaluation.

    Args:
        question: Question text to answer.
        context: Passage the answer should be drawn from.
        model_name: Chat model used to answer.
        max_tokens_for_answer: Completion token cap (default 30).
        retries: Maximum number of attempts; ``retries <= 0`` makes no request.
        delay: Seconds to sleep between consecutive failed attempts.

    Returns:
        The stripped reply, or the sentinel ``"Failed to get answer"`` if no
        attempt succeeds. Callers compare against this exact string.
    """
    prompt = (
        "Please answer the question based on the following context:\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        f"Question: {question}\n"
        "\n"
        'Only output the answer, no any other text. If the answer is not in the context, please say "I don\'t know".\n'
        "\n"
        "Answer:"
    )
    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[{'role': 'user', 'content': prompt}],
                max_tokens=max_tokens_for_answer,
            )
            # If message.content is None, .strip() raises AttributeError and
            # the attempt is handled like any other failure below.
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"Answer generation failed for model {model_name} (attempt {attempt + 1}/{retries}): {e}")
            if attempt < retries - 1:
                time.sleep(delay)
    return "Failed to get answer"
| |
|
def main(args):
    """Run the rephrase-then-answer exact-match evaluation described by ``args``.

    For each sample: rephrase its question with GPT-4o according to
    ``args.rephrase_type``, answer the rephrased question from the sample's
    context with ``args.model_name``, and count how often that answer exactly
    matches (EM) any of the sample's ORIGINAL ground-truth answers.

    ``args`` must provide: model_name, dataset_name, dataset_subset,
    sample_count (negative means "all samples"), and rephrase_type.
    Samples are read from the dataset's "test" split using the keys
    'input', 'context' and 'answers'.
    """
    rephrase_model = "gpt-4o"
    max_answer_tokens = 30

    print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
    try:
        dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
        print(f"Successfully loaded dataset with {len(dataset)} samples.")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    em_match_count = 0
    successfully_processed_samples = 0

    # Any negative sample_count (the CLI default is -1) means "all samples";
    # positive counts are clamped to the dataset size.
    if args.sample_count < 0:
        num_samples_to_process = len(dataset)
    else:
        num_samples_to_process = min(args.sample_count, len(dataset))

    print(
        f"Processing {num_samples_to_process} samples. "
        f"Rephrasing with {rephrase_model} ({args.rephrase_type} meaning). "
        f"Answering with {args.model_name} (max {max_answer_tokens} tokens for answer)..."
    )

    for i in tqdm(range(num_samples_to_process), desc="Processing samples"):
        example = dataset[i]
        original_question = example['input']
        context = example['context']
        ground_truth_answers = example['answers']

        # Check this before any API call: a sample without ground truth cannot
        # be scored, so spending two requests on it would be wasted.
        if not ground_truth_answers:
            print(f"Skipping sample {i+1} due to missing ground truth answers.")
            continue

        print(f"Original question: {original_question}")

        rephrased_question = rephrase_question_api(original_question, rephrase_model, args.rephrase_type)
        print(f"Rephrased question ({args.rephrase_type}): {rephrased_question}")

        # Failure sentinels returned by the rephrasing helpers.
        if rephrased_question in ("Failed to get response", "Failed to rephrase question"):
            print(f"Skipping sample {i+1} due to rephrasing failure.")
            continue

        rephrased_answer = answer_question_with_context_api(
            rephrased_question, context, args.model_name, max_tokens_for_answer=max_answer_tokens
        )
        if rephrased_answer == "Failed to get answer":
            print(f"Skipping sample {i+1} due to answer generation failure.")
            continue

        successfully_processed_samples += 1
        # The answer to the REPHRASED question is scored against the ORIGINAL
        # question's ground truth; one matching reference is enough.
        if any(qa_em_score(rephrased_answer, gt_ans) > 0 for gt_ans in ground_truth_answers):
            em_match_count += 1

    if successfully_processed_samples > 0:
        em_rate = em_match_count / successfully_processed_samples
        print("\n--- Evaluation Summary ---")
        print(f"Answering Model : {args.model_name}")
        print(f"Rephrase Type : {args.rephrase_type}")
        print(f"Dataset : {args.dataset_name} ({args.dataset_subset})")
        print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
        print(f"Max Answer Tokens: {max_answer_tokens}")
        print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
        print(f"EM rate with original ground truth (after rephrase): {em_rate:.2%}")
    else:
        print("\nNo samples were processed adequately to provide an evaluation summary.")

    print("Processing complete!")
| |
|
if __name__ == "__main__":
    # Command-line entry point: parse options and run the evaluation.
    parser = argparse.ArgumentParser(
        description=(
            "Rephrase questions (opposite or similar meaning) with GPT-4o, answer with the "
            "specified OpenAI model, then count EM against the original ground truth."
        )
    )
    parser.add_argument("--model_name", type=str, default="gpt-4o", help="Name of the OpenAI model to use for answering.")
    parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
    parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
    parser.add_argument("--sample_count", type=int, default=-1, help="Number of samples to process. -1 for all samples.")
    parser.add_argument(
        "--rephrase_type",
        type=str,
        default="opposite",
        choices=["opposite", "similar"],
        help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.",
    )

    args = parser.parse_args()
    main(args)