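"""Evaluate pre-ranking and re-ranking retrieval performance on the training
queries, comparing each ranking against the gold-standard document mapping."""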
import os
import json
import argparse

# Import metrics directly from the local file
from metrics import mean_recall_at_k, mean_average_precision, mean_inv_ranking, mean_ranking
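
# The metrics module is local to this project and not shown here. Judging from
# the calls in main(), each metric takes two parallel lists: true_labels[i] is
# the collection of relevant document IDs for query i, and pred_labels[i] is
# the ranked list retrieved for that query. The function below is a minimal
# sketch of mean Recall@k under that assumption (not the actual metrics.py
# code), kept under a separate name so it never shadows the real import.
def _reference_mean_recall_at_k(true_labels, pred_labels, k=10):
    """Sketch: mean over queries of |relevant docs in top k| / |relevant docs|."""
    scores = []
    for true_docs, ranked in zip(true_labels, pred_labels):
        relevant = set(true_docs)
        hits = len(relevant & set(ranked[:k]))
        scores.append(hits / len(relevant) if relevant else 0.0)
    return sum(scores) / len(scores) if scores else 0.0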
def load_json_file(file_path):
    """Load JSON data from a file."""
    with open(file_path, 'r') as f:
        return json.load(f)
def main():
    base_directory = os.getcwd()

    parser = argparse.ArgumentParser(description='Evaluate document ranking performance on training data')
    parser.add_argument('--pre_ranking', type=str, default='shuffled_pre_ranking.json',
                        help='Path to pre-ranking JSON file')
    parser.add_argument('--re_ranking', type=str, default='predictions2.json',
                        help='Path to re-ranked JSON file')
    parser.add_argument('--gold', type=str, default='train_gold_mapping.json',
                        help='Path to gold standard mapping JSON file (training only)')
    parser.add_argument('--train_queries', type=str, default='train_queries.json',
                        help='Path to training queries JSON file')
    parser.add_argument('--k_values', type=str, default='3,5,10,20',
                        help='Comma-separated list of k values for Recall@k')
    parser.add_argument('--base_dir', type=str,
                        default=f'{base_directory}/Patent_Retrieval/datasets',
                        help='Base directory for data files')
    args = parser.parse_args()
    # Resolve paths relative to base_dir unless they are already absolute
    def get_full_path(path):
        if os.path.isabs(path):
            return path
        return os.path.join(args.base_dir, path)
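    # Example: with the default base_dir, get_full_path('train_queries.json')
    # resolves to '<cwd>/Patent_Retrieval/datasets/train_queries.json';
    # absolute paths pass through unchanged.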
    # Load the training queries
    print("Loading training queries...")
    train_queries = load_json_file(get_full_path(args.train_queries))
    print(f"Loaded {len(train_queries)} training queries")

    # Load the ranking data and gold standard
    print("Loading ranking data and gold standard...")
    pre_ranking = load_json_file(get_full_path(args.pre_ranking))
    re_ranking = load_json_file(get_full_path(args.re_ranking))
    gold_mapping = load_json_file(get_full_path(args.gold))
    # Keep only queries that appear in the training set
    pre_ranking = {fan: docs for fan, docs in pre_ranking.items() if fan in train_queries}
    re_ranking = {fan: docs for fan, docs in re_ranking.items() if fan in train_queries}
    gold_mapping = {fan: docs for fan, docs in gold_mapping.items() if fan in train_queries}
    # Parse k values (e.g. '3,5,10,20' -> [3, 5, 10, 20])
    k_values = [int(k) for k in args.k_values.split(',')]
    # Evaluate only the queries present in all three mappings
    query_fans = set(gold_mapping.keys()) & set(pre_ranking.keys()) & set(re_ranking.keys())
    if not query_fans:
        print("Error: No common query FANs found across all datasets!")
        return

    print(f"Evaluating rankings for {len(query_fans)} training queries...")
    # Extract true and predicted labels for both rankings; sorting gives a
    # deterministic query order and keeps the three lists aligned by index
    query_fans = sorted(query_fans)
    true_labels = [gold_mapping[fan] for fan in query_fans]
    pre_ranking_labels = [pre_ranking[fan] for fan in query_fans]
    re_ranking_labels = [re_ranking[fan] for fan in query_fans]
    # Calculate metrics for the pre-ranking
    print("\nPre-ranking performance (training queries only):")
    for k in k_values:
        recall_at_k = mean_recall_at_k(true_labels, pre_ranking_labels, k=k)
        print(f"  Recall@{k}: {recall_at_k:.4f}")
    map_score = mean_average_precision(true_labels, pre_ranking_labels)
    print(f"  MAP: {map_score:.4f}")
    inv_rank = mean_inv_ranking(true_labels, pre_ranking_labels)
    print(f"  Mean Inverse Rank: {inv_rank:.4f}")
    rank = mean_ranking(true_labels, pre_ranking_labels)
    print(f"  Mean Rank: {rank:.2f}")
    # Calculate metrics for the re-ranking
    print("\nRe-ranking performance (training queries only):")
    for k in k_values:
        recall_at_k = mean_recall_at_k(true_labels, re_ranking_labels, k=k)
        print(f"  Recall@{k}: {recall_at_k:.4f}")
    map_score = mean_average_precision(true_labels, re_ranking_labels)
    print(f"  MAP: {map_score:.4f}")
    inv_rank = mean_inv_ranking(true_labels, re_ranking_labels)
    print(f"  Mean Inverse Rank: {inv_rank:.4f}")
    rank = mean_ranking(true_labels, re_ranking_labels)
    print(f"  Mean Rank: {rank:.2f}")
if __name__ == "__main__":
    main()
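
# Example invocation (the filename evaluate_train.py is an assumption;
# substitute whatever this script is saved as). Relative paths are resolved
# under --base_dir:
#
#   python evaluate_train.py --pre_ranking shuffled_pre_ranking.json \
#       --re_ranking predictions2.json --gold train_gold_mapping.json \
#       --k_values 3,5,10,20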