import os
import json
import base64
import argparse
import time
import re
from datetime import datetime
from functools import partial
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
from multiprocessing import Pool, Manager


REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""


def parse_arguments():
    """
    Parse command line arguments for evaluation configuration.

    Returns:
        argparse.Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(description="Video QA Evaluation Framework")

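    # Model to evaluate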
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )

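    # Frame sampling and dataset inputs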
    parser.add_argument(
        "--frame-num",
        "-fn",
        type=int,
        default=32,
        help="Number of frames to uniformly sample from each video (default: 32)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )

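    # API retry and parallelism settings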
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )

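    # Endpoint and credentials for the target API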
    parser.add_argument(
        "--base_url",
        type=str,
        required=True,
        help="API endpoint URL (Azure OpenAI, Volcengine Ark, or an OpenAI-compatible server).",
    )
    parser.add_argument(
        "--api_key", type=str, required=True, help="API key for the chosen endpoint."
    )

    return parser.parse_args()


def save_json_file(data, output_file):
    """
    Save data to a JSON file.

    Args:
        data (dict or list): Data to be saved.
        output_file (str): Path to the output file.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """
    Extracts a JSON object from a string that contains reasoning followed by a tagged JSON block.

    Args:
        response (str): The raw response string from the model.

    Returns:
        dict or None: Parsed JSON object or None if no valid JSON block is found.
    """
    if not response:
        return None
    try:
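        # Pull out the first ```json ... ``` fenced block and parse the object inside it.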
        match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        return None
    except (json.JSONDecodeError, IndexError):
        return None


def calculate_metrics(results):
    """
    Calculate evaluation metrics from the results.

    Args:
        results (list): List of results with 'is_correct' field.

    Returns:
        dict: Dictionary containing calculated metrics. Accuracy is computed
            over answered samples only.
    """
    total_samples = len(results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered_samples = sum(1 for x in results if x.get("model_answer") is not None)
    correct_answers = sum(1 for x in results if x.get("is_correct"))

    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0

    metrics = {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }

    return metrics


def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Make a single API call to the specified model with retry logic.

    Args:
        client: OpenAI client instance.
        messages (list): List of messages for the API call.
        model (str): Model name to use.
        item_id (str): ID of the item being processed (for error logging).
        max_retry_times (int): Maximum number of retries.

    Returns:
        str or None: Model response or None if all retries failed.
    """
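    # Doubao models get a smaller completion-token budget than the other endpoints.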
    if "doubao" in model:
        max_tokens = 32768
    else:
        max_tokens = 65535
    retry_times = 0
    while retry_times < max_retry_times:
        try:
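            # One chat-completions call; any exception below counts as a retry.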
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=max_tokens
            )
            return completion.choices[0].message.content
        except Exception as e:
            retry_times += 1
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            time.sleep(5)


def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single data item using the target model and perform exact match.

    Args:
        data_item (dict): Dictionary containing question and answer data.
        frames (list): List of encoded video frames.
        target_model (str): Model to be evaluated.
        api_key (str): API key.
        base_url (str): API base URL.
        max_retry_times (int): Maximum number of retries.

    Returns:
        dict: Evaluation result.
    """
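    # Pick the client by endpoint: Ark for Volcengine, plain OpenAI for Aliyun or
    # local servers, Azure OpenAI otherwise.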
    if "ark" in base_url:
        client = Ark(
            base_url=base_url,
            api_key=api_key,
        )
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

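    # Build the multimodal prompt: system template, sampled frames, then the question.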
    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

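    # Keep the original item fields and attach the model's raw and parsed outputs.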
    result = {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }

    return result


def encode_image(image_path):
    """
    Encode an image file to base64 string.

    Args:
        image_path (str): Path to the image file.

    Returns:
        str: Base64 encoded image string.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_frames(frames_path, frame_num):
    """
    Process and uniformly sample video frames from a directory, then encode them.

    Args:
        frames_path (str): Path to the directory containing video frames.
        frame_num (int): The number of frames to sample.

    Returns:
        list: List of encoded frame objects for API consumption.
    """
    if not os.path.isdir(frames_path):
        print(f"Warning: Frame directory not found at {frames_path}")
        return []

    frame_files = [
        f
        for f in os.listdir(frames_path)
        if f.startswith("frame_") and f.endswith(".jpg")
    ]
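    # Sort numerically by the index embedded in "frame_<idx>.jpg".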
    frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))

    frame_path_list = [os.path.join(frames_path, f) for f in frame_files]
    total_frames = len(frame_path_list)

    if total_frames == 0:
        return []

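    # Uniformly sample frame_num frames; keep every frame if the video has fewer.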
    if total_frames > frame_num:
        indices = [int(i * total_frames / frame_num) for i in range(frame_num)]
        sampled_paths = [frame_path_list[i] for i in indices]
    else:
        sampled_paths = frame_path_list

    base64_images = [encode_image(path) for path in sampled_paths]

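    # Wrap each frame as an OpenAI-style image_url content part (data URI).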
    return [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
        for b64_img in base64_images
    ]


def process_single_data(
    data_item, args, shared_results, progress_counter, total_items, locks
):
    """
    Process a single data item in a multiprocessing context.

    Args:
        data_item (dict): Single data item to process.
        args: Command line arguments.
        shared_results: Shared list for storing results.
        progress_counter: Shared counter for progress tracking.
        total_items (int): Total number of items to process.
        locks (dict): Dictionary of locks for thread-safe operations.
    """
    item_key = data_item["key"]
    try:
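        # Frames for this item live under <frames_path>/<key>/.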
        specific_frames_path = os.path.join(args.frames_path, item_key)
        frames = process_frames(specific_frames_path, args.frame_num)

        if not frames:
            raise FileNotFoundError(
                f"No frames found or processed for key '{item_key}' at path '{specific_frames_path}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            with locks["results"]:
                shared_results.append(result)
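                # Snapshot the results file after every item so partial progress survives interruptions.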
                data_filename_base = os.path.splitext(os.path.basename(args.data_file))[
                    0
                ]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = (
                    f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames"
                )
                results_output_file = f"{output_prefix}_results.json"
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
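        # Always advance the shared progress counter, even on failure.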
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )


def load_test_data(json_file):
    """
    Load test data from a JSON file.

    Args:
        json_file (str): Path to the JSON file.

    Returns:
        list: List of test data items.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)


def main():
    """
    Main function to run the video QA evaluation framework.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample: {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

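    # Start a fresh error log for this run.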
    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

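    # Output filenames encode the model, dataset, and frame count.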
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    test_data = load_test_data(args.data_file)
    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

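    # Shared state (results list, progress counter, locks) lives in a Manager so
    # worker processes can update it safely.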
    with Manager() as manager:
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)

        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

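        # Bind the fixed arguments; Pool.map supplies each data item.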
        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
        )

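        # Evaluate all items in parallel.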
        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        all_results = list(shared_results)

        print(f"\n\nProcessing complete for model: {args.target_model}")

        final_metrics = calculate_metrics(all_results)
        save_json_file(final_metrics, metrics_output_file)
        print(f"\nMetrics saved to: {metrics_output_file}")
        print(json.dumps(final_metrics, indent=4))

        save_json_file(all_results, results_output_file)
        print(f"Detailed results saved to: {results_output_file}")


if __name__ == "__main__":
    main()