| | import json |
| | import time |
| | import re |
| | import os |
| | import argparse |
| | from datasets import load_dataset |
| | from nltk.tokenize import sent_tokenize |
| | from utils.util import retriveDoc,compute_best_sentence_f1 |
| | from openai import OpenAI |
| | import asyncio, json, torch, math |
| | from typing import List, Tuple |
| | |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
| | from utils.metrics import qa_f1_score |
| | from utils.llmjudge import judge_answer_with_api |
| |
|
| |
|
# Shared OpenAI-compatible client used by get_answer_with_retry() and
# random_alternative_answer(). Endpoint and key are taken from the
# OPENAI_BASE_URL / OPENAI_API_KEY environment variables; each lookup yields
# None when the variable is unset.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)
| | |
| |
|
# Local Qwen-14B-Chat checkpoint, loaded in bfloat16 on the first GPU.
# main() passes this pair to get_transformers_answer() to re-answer a question
# from the retrieved evidence only.
tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
model1 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat",
    trust_remote_code=True,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
).eval()  # inference only; kept explicit so both local models are loaded the same way
| |
|
| |
|
# Second local model: Qwen2.5-7B-Instruct in bfloat16 on the second GPU, put in
# eval mode right after loading.
# NOTE(review): neither name is referenced anywhere else in the visible part of
# this file - confirm it is still needed before paying for the GPU memory.
tok_qwen = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    trust_remote_code=True,
)
model_qwen = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="cuda:1",
    trust_remote_code=True,
)
model_qwen = model_qwen.eval()
| |
|
def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
    """
    Generate a reply to `prompt` with a local transformers model.

    The prompt is wrapped as a single user message and rendered with the
    tokenizer's chat template (falling back to the raw prompt when the
    template cannot be applied). The prompt tokens are sliced off the
    generated sequence so only the newly generated text is returned.

    Args:
        prompt: User prompt text.
        tokenizer: Tokenizer providing `apply_chat_template`, `__call__` and `decode`.
        model: Model providing `generate` and a `device` attribute.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value of None or <= 0 selects
            greedy decoding instead of sampling.
        top_p: Nucleus-sampling threshold; only sent when sampling.
        retries: Number of attempts before giving up.
        delay: Seconds to sleep between failed attempts.

    Returns:
        The stripped generated text, or None when every attempt failed
        (or `retries` is not positive).
    """
    # temperature/top_p are sampling parameters: they are only honoured when
    # do_sample=True, so sampling is requested explicitly instead of relying on
    # whatever default the checkpoint's generation config happens to carry.
    generation_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature is not None and temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
        if top_p is not None:
            generation_kwargs["top_p"] = top_p
    else:
        generation_kwargs["do_sample"] = False

    for attempt in range(retries):
        try:
            messages = [{"role": "user", "content": prompt}]

            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                print(f"Unable to apply chat template: {e}, falling back to basic text input")
                formatted_prompt = prompt

            model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
            generated_ids = model.generate(**model_inputs, **generation_kwargs)

            # The returned sequence starts with the prompt tokens; keep only
            # what comes after them.
            input_length = model_inputs.input_ids.shape[1]
            output_ids = generated_ids[0][input_length:]

            return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")

    return None
| |
|
def truncate_answer(answer):
    """Return the stripped text before the first period, or "No answer" for empty/None input."""
    if not answer:
        return "No answer"
    leading_part, _, _ = answer.partition('.')
    return leading_part.strip()
| |
|
def write_to_log(filename, data):
    """Append `data` followed by a newline to the UTF-8 text file `filename`."""
    with open(filename, mode='a', encoding='utf-8') as log_handle:
        line = data + '\n'
        log_handle.write(line)
| |
|
def remove_think_tags(text: str) -> str:
    """Drop every complete <think> ... </think> block (may span lines) and strip the result."""
    think_block = re.compile(r'<think>(.*?)</think>', re.DOTALL)
    without_reasoning = think_block.sub('', text)
    return without_reasoning.strip()
| |
|
def build_prompt(context: str, question: str) -> str:
    """Build the QA prompt that asks for an answer plus step-by-step reasoning with quoted references.

    The prompt embeds the passages (`context`) and the `question`, then
    instructs the model to reply with "Answer: ..." followed by a
    "Step-by-step Reasoning:" section made of numbered steps, each followed
    by the original sentence(s) it relies on.
    """
    segments = [
        # Task statement with the passages and the question.
        f"Answer the question based on the given passages. The following are the passages:\n",
        f"{context}\n",
        f"Answer the question based on the given passages.\n",
        f"Question: {question}.\n",
        f"Answer:\n",
        # Instructions on how to answer and how to cite.
        f"Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) ",
        f"At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. ",
        f"Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. ",
        # Expected output layout.
        f"Follow this format:\n",
        f"Answer: [Your answer]\n",
        f"Step-by-step Reasoning:\n",
        f"1. [Reasoning step 1]\n",
        f"[replaced by your reference content]\n",
        f"2. [Reasoning step 2]\n",
        f"[replaced by your reference content]\n",
    ]
    return "".join(segments)
| |
|
def extract_final_bullet_passage(answer_text: str):
    """Return (passage_number, snippet) for the last "Passage N: ..." citation in the reasoning.

    Only the text after "Step-by-step Reasoning:" is inspected. It is split
    into numbered steps ("1. ...", "2. ..."), which are scanned from the last
    one backwards; the final citation of the first step that has one wins.
    The unquoted form of the snippet is preferred, the quoted form is used
    otherwise. Returns (None, None) when there is no reasoning section, no
    numbered step, or no citation.
    """
    reasoning_section = re.search(r"Step-by-step Reasoning:\s*(.*)", answer_text, flags=re.DOTALL)
    if reasoning_section is None:
        return None, None

    numbered_steps = re.findall(
        r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)",
        reasoning_section.group(1).strip(),
        flags=re.DOTALL,
    )
    if not numbered_steps:
        print("No bullet blocks found.")
        return None, None

    citation_re = re.compile(
        r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
        flags=re.DOTALL
    )

    for step in numbered_steps[::-1]:
        citations = citation_re.findall(step)
        if not citations:
            continue
        # Each findall tuple is (number, whole snippet, quoted body, unquoted body).
        number, _whole, quoted_body, unquoted_body = citations[-1]
        unquoted_body = unquoted_body.strip()
        return number, (unquoted_body or quoted_body.strip())

    return None, None
| |
|
def extract_all_bullet_passages(answer_text: str):
    """Split the "Step-by-step Reasoning:" section into its numbered steps.

    Returns a list of {'bullet_index': i, 'snippet': text} dicts, where `i`
    counts from 1 in order of appearance and `text` is the stripped step
    (number prefix included). Returns [] when the section or the numbered
    steps are missing. The non-empty result is also printed.
    """
    reasoning_section = re.search(r"Step-by-step Reasoning:\s*(.*)", answer_text, flags=re.DOTALL)
    if reasoning_section is None:
        return []

    numbered_steps = re.findall(
        r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)",
        reasoning_section.group(1).strip(),
        flags=re.MULTILINE | re.DOTALL,
    )
    if not numbered_steps:
        return []

    results = [
        {'bullet_index': position, 'snippet': step.strip()}
        for position, step in enumerate(numbered_steps, start=1)
    ]
    print(results)
    return results
| |
|
def extract_evidence(answer_text: str):
    """Extract the numbered items that follow the first "Evidence" marker (case-insensitive).

    Items before the first one starting with "1." are discarded. The rest is
    returned as a list of {'bullet_index': i, 'snippet': text} dicts with `i`
    renumbered from 1. Returns [] when the marker, the numbered items, or an
    item starting with "1." is missing.
    """
    evidence_section = re.search(r"(?i)Evidence\s*(.*)", answer_text, flags=re.DOTALL)
    if evidence_section is None:
        return []

    numbered_items = re.findall(
        r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)",
        evidence_section.group(1).strip(),
        flags=re.MULTILINE | re.DOTALL,
    )

    # Position of the first item numbered "1."; anything before it is noise.
    first_item_pos = next(
        (pos for pos, item in enumerate(numbered_items) if item.strip().startswith("1.")),
        None,
    )
    if first_item_pos is None:
        return []

    return [
        {'bullet_index': number, 'snippet': item.strip()}
        for number, item in enumerate(numbered_items[first_item_pos:], start=1)
    ]
| |
|
| |
|
def get_answer_with_retry(model, prompt, retries=3, delay=5):
    """Send `prompt` as a single user message to `model` through the shared client.

    Any exception (including a reply without text content) counts as a failed
    attempt. Up to `retries` attempts are made, sleeping `delay` seconds
    between them. Returns the stripped reply text, or None once all attempts
    have failed.
    """
    attempts_made = 0
    while attempts_made < retries:
        attempts_made += 1
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}]
            )
            reply_text = response.choices[0].message.content
            return reply_text.strip()
        except Exception as e:
            print(f"Error on attempt {attempts_made}: {e}")
            if attempts_made >= retries:
                print("Max retries reached, skipping this request.")
                break
            print(f"Retrying in {delay} seconds...")
            time.sleep(delay)
    return None
| |
|
| | def extract_json_from_gpt_response(text: str) -> dict | None: |
| | """ |
| | Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict. |
| | """ |
| | |
| | m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL) |
| | if not m: |
| | |
| | m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL) |
| | if not m: |
| | |
| | m = re.search(r"(\{.*?\})", text, flags=re.DOTALL) |
| | if not m: |
| | return None |
| |
|
| | json_str = m.group(1) |
| | try: |
| | return json.loads(json_str) |
| | except json.JSONDecodeError: |
| | |
| | cleaned = re.sub(r",\s*([\]}])", r"\1", json_str) |
| | try: |
| | return json.loads(cleaned) |
| | except json.JSONDecodeError: |
| | return None |
| |
|
async def random_alternative_answer(
    question: str,
    original_context: str,
    unique_sents: List[str],
    correct_answer: str
) -> dict:
    """Generate a plausible alternative answer and rewrite the evidence to support it.

    gpt-4o is asked (through the shared client) for an answer different from
    `correct_answer` together with rewritten versions of `unique_sents`. Each
    original sentence found in `original_context` is then replaced by its
    rewrite.

    Returns:
        {"context": <rewritten context>, "answer": <alternative answer>} on
        success. When the reply cannot be parsed or lacks a usable "answer" /
        "revised" field, returns
        {"context": original_context, "answer": "Failed to generate alternative"}.

    Note: the chat-completion call is synchronous, so it blocks the event loop
    while it runs; exceptions raised by the API call propagate to the caller.
    """
    failure = {"context": original_context, "answer": "Failed to generate alternative"}

    numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
    prompt = (
        "You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
        "Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
        "Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
        "Output must be strictly in the specified JSON format, with no additional text.\n"
        '{\n'
        '  "answer": "<your alternative answer here, just provide the answer phrase, no need for complete sentence>",\n'
        '  "revised": [\n'
        '    "<rewritten sentence 1>",\n'
        '    "<rewritten sentence 2>",\n'
        '    ...\n'
        '  ]\n'
        '}\n\n'
        f"Question:\n{question}\n\n"
        f"Original answer:\n{correct_answer}\n\n"
        f"Sentences to rewrite:\n{numbered}"
    )

    print(f"[Alternative Answer] Generating prompt: {prompt}")

    rsp = client.chat.completions.create(
        model="gpt-4o", temperature=0.7,
        messages=[{"role": "user", "content": prompt}]
    )

    # The message content can be None (e.g. a refusal); treat it as empty text.
    js = extract_json_from_gpt_response(rsp.choices[0].message.content or "")
    if not js or not isinstance(js, dict):
        print("[Alternative Answer] Failed to parse JSON")
        return failure

    alternative = js.get("answer")
    revised = js.get("revised")
    if alternative is None or not isinstance(revised, list):
        print("[Alternative Answer] JSON is missing a usable 'answer' or 'revised' field")
        return failure

    if len(revised) != len(unique_sents):
        print(
            f"[Alternative Answer] Expected {len(unique_sents)} rewritten sentences, "
            f"got {len(revised)}; unmatched sentences are left unchanged"
        )

    new_ctx = original_context
    for old, new in zip(unique_sents, revised):
        # An empty `old` would make str.replace insert `new` between every
        # character of the context; a non-string `new` cannot be spliced in.
        if not old or not isinstance(new, str):
            continue
        new_ctx = new_ctx.replace(old, new)

    return {"context": new_ctx, "answer": alternative}
| |
|
def main():
    """Generate alternative-answer records for a QA dataset and append them to a JSONL file.

    For every selected example:
      1. deepseek-r1 answers the question with the full context and lists its
         reasoning steps; the example is skipped unless the judge accepts the
         answer.
      2. Each reasoning step is used to retrieve supporting passages from the
         context; examples without any extracted step are skipped.
      3. The local Qwen-14B-Chat model answers from the retrieved passages
         only. When that answer is missing or rejected by the judge, passages
         retrieved with the question itself are added as well.
      4. An alternative answer and a rewritten context are generated from the
         collected passages, and {"question", "answer", "context"} is appended
         to the output file. Failed generations are not written.
    """
    parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
    parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
                        help="Output JSONL file path (default: output_random.jsonl)")
    parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
                        help="Dataset repository name (default: THUDM/LongBench)")
    parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
                        help="Dataset subset name (default: hotpotqa)")
    parser.add_argument("--split", type=str, default="test",
                        help="Dataset split (default: test)")
    parser.add_argument("--start_idx", type=int, default=0,
                        help="Starting index for processing (default: 0)")
    parser.add_argument("--max_samples", type=int, default=-1,
                        help="Maximum number of samples to process (-1 for all, default: -1)")

    args = parser.parse_args()

    out_file = args.output

    longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]

    print(f"Output file: {out_file}")
    print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
    print(f"Total samples: {len(longbench)}")

    start_idx = args.start_idx
    # Any negative value means "no limit", not only the documented -1.
    if args.max_samples < 0:
        end_idx = len(longbench)
    else:
        end_idx = min(start_idx + args.max_samples, len(longbench))

    print(f"Processing samples from index {start_idx} to {end_idx-1}")

    written = 0

    for idx in range(start_idx, end_idx):
        example = longbench[idx]
        question = example['input']
        print(f"Question: {question}")
        context = example['context']
        answers = example['answers']
        if not answers:
            print("Example has no reference answer, skipping.")
            continue
        correct_answer = answers[0]

        print(f"Processing example {idx + 1}:")
        print(f"Correct Answer: {correct_answer}")

        prompt_with_context = build_prompt(context, question)
        raw_answer = get_answer_with_retry('deepseek-r1', prompt_with_context)
        if raw_answer is None:
            # Every retry failed; there is nothing to judge or to extract from.
            print("No answer returned by the model, skipping this example.")
            continue

        # Strip the <think> blocks first so an "Answer:" inside the model's
        # private reasoning is not mistaken for the final answer.
        answer_with_context = remove_think_tags(raw_answer)
        answer_with_context_simple = (
            answer_with_context
            .split("Answer:", 1)[-1]
            .split("Step-by-step Reasoning", 1)[0]
            .strip()
        )

        print(f"Answer with context: {answer_with_context_simple}")
        answer_accepted = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
        print(f"Answer judge result: {answer_accepted}")

        if not answer_accepted:
            continue

        evidence = extract_all_bullet_passages(answer_with_context)
        if not evidence:
            # Without reasoning steps there are no sentences to rewrite.
            print("No evidence extracted, skipping this example.")
            continue

        page_contents = []
        for ev in evidence:
            retrieved_docs = retriveDoc(context, ev['snippet'])
            page_contents.extend(doc.page_content for doc in retrieved_docs)

        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_page_contents = list(dict.fromkeys(page_contents))
        aggregated_content = "\n".join(unique_page_contents)

        prompt_final = (
            f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
            f"Please only provide your answer. "
            f"Your Answer:"
        )

        final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)

        # A missing local answer (generation failed) counts as incorrect.
        if final_answer is not None and judge_answer_with_api(question, correct_answer, final_answer):
            print("correct")
        else:
            print("incorrect")
            # The evidence alone was not enough: add passages retrieved with
            # the question itself.
            query_docs = retriveDoc(context, question)
            page_contents.extend(doc.page_content for doc in query_docs)
            unique_page_contents = list(dict.fromkeys(page_contents))

        alternative = asyncio.run(
            random_alternative_answer(
                question,
                context,
                unique_page_contents,
                correct_answer
            )
        )

        # random_alternative_answer reports failure through this sentinel
        # answer; writing it would put a bogus sample into the dataset.
        if alternative["answer"] == "Failed to generate alternative":
            print("Alternative answer generation failed, skipping this example.")
            continue

        record = {
            "question": question,
            "answer": alternative["answer"],
            "context": alternative["context"]
        }

        # Append per record so finished work survives a later crash.
        with open(out_file, "a", encoding="utf-8") as fout:
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
        written += 1

    print(f"Records written: {written}")


if __name__ == "__main__":
    main()