Spaces:
Running
Running
| from commonsenseConstraint import evaluation as commonsense_eval | |
| from hardConstraint import evaluation as hard_eval | |
| import json | |
| from tqdm import tqdm | |
| from datasets import load_dataset | |
| def load_line_json_data(filename): | |
| data = [] | |
| with open(filename, 'r', encoding='utf-8') as f: | |
| for line in f.read().strip().split('\n'): | |
| unit = json.loads(line) | |
| data.append(unit) | |
| return data | |
| def count_true_false(data): | |
| """Count the number of true and false values in a list.""" | |
| true_count = data.count(True) | |
| false_count = data.count(False) | |
| return true_count, false_count | |
| def statistics(commonsense_statistic): | |
| """Generate statistics for each level and day in the given data with a different structure.""" | |
| result = {level: {day: {} for day in commonsense_statistic[level]} for level in commonsense_statistic} | |
| for level, days in commonsense_statistic.items(): | |
| for day, dicts in days.items(): | |
| for dct in dicts: | |
| if dct: | |
| for key, data in dct.items(): | |
| true_count, false_count = count_true_false(data) | |
| if key not in result[level][day]: | |
| result[level][day][key] = {"true": 0, "false": 0} | |
| result[level][day][key]["true"] += true_count | |
| result[level][day][key]["false"] += false_count | |
| return result | |
| def eval_score(validation_or_test: str, file_path: str, TOKEN): | |
| if validation_or_test == 'validation': | |
| query_data_list = load_dataset('osunlp/TravelPlannerEval','validation',token=TOKEN)['validation'] | |
| elif validation_or_test == 'test': | |
| query_data_list = load_dataset('osunlp/TravelPlannerEval','test',token=TOKEN)['test'] | |
| query_data_list = [x for x in query_data_list] | |
| hardConstraint_statistic= {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']} | |
| commonsenseConstraint_statistic = {level:{day:[] for day in [3,5,7]} for level in ['easy','medium','hard']} | |
| tested_plans = load_line_json_data(file_path) | |
| delivery_cnt = 0 | |
| plan_constraint_store = [] | |
| for idx in tqdm(range(0,len(query_data_list))): | |
| query_data = query_data_list[idx] | |
| tested_plan = tested_plans[idx] | |
| if type(query_data) == str: | |
| query_data = eval(query_data) | |
| if type(tested_plan) == str: | |
| tested_plan = eval(tested_plan) | |
| if type(query_data['local_constraint']) == str: | |
| query_data['local_constraint'] = eval(query_data['local_constraint']) | |
| if tested_plan['plan']: | |
| delivery_cnt += 1 | |
| commonsense_info_box = commonsense_eval(query_data,tested_plan['plan']) | |
| else: | |
| commonsense_info_box = None | |
| if commonsense_info_box and commonsense_info_box['is_not_absent'][0] and commonsense_info_box['is_valid_information_in_sandbox'][0]: | |
| hard_info_box = hard_eval(query_data,tested_plan['plan']) | |
| else: | |
| hard_info_box = None | |
| plan_constraint_store.append({'commonsense_constraint':commonsense_info_box,'hard_constraint':hard_info_box}) | |
| commonsenseConstraint_statistic[query_data['level']][query_data['days']].append(commonsense_info_box) | |
| hardConstraint_statistic[query_data['level']][query_data['days']].append(hard_info_box) | |
| commonsenseConstraint_statistic_processed = statistics(commonsenseConstraint_statistic) | |
| hardConstraint_statistic_processed = statistics(hardConstraint_statistic) | |
| # print(commonsenseConstraint_statistic_processed) | |
| # print(hardConstraint_statistic_processed) | |
| constraint_record = {key: {day: {'house rule':0, 'cuisine':0, 'room type':0, 'transportation':0} for day in [3,5,7]} for key in ['medium','hard']} | |
| constraint_mapping = {'house rule':'valid_room_rule','cuisine':'valid_cuisine','room type':'valid_room_type','transportation':'valid_transportation'} | |
| mapping_constraint_record = {key: {day: {'valid_room_rule':0, 'valid_cuisine':0, 'valid_room_type':0, 'valid_transportation':0} for day in [3,5,7]} for key in ['medium','hard']} | |
| count_record = {key:{day:0 for day in [3,5,7]} for key in ['easy','medium','hard']} | |
| for unit in query_data_list: | |
| count_record[unit['level']][unit['days']] += 1 | |
| for key in constraint_record['medium'][3]: | |
| if unit['local_constraint'][key] != None: | |
| constraint_record[unit['level']][unit['days']][key] += 1 | |
| mapping_constraint_record[unit['level']][unit['days']][constraint_mapping[key]] += 1 | |
| data_record = {key:{day:[] for day in [3,5,7]} for key in ['easy','medium','hard']} | |
| constraint_dis_record = {"commonsense":{"pass":0,"total":0},"hard":{"pass":0,"total":0}} | |
| for constraint in ['commonsense','hard']: | |
| if constraint == 'commonsense': | |
| constraint_statistic = commonsenseConstraint_statistic_processed | |
| elif constraint == 'hard': | |
| constraint_statistic = hardConstraint_statistic_processed | |
| key_dict = {'commonsense':['is_valid_information_in_current_city','is_valid_information_in_sandbox','is_reasonalbe_visiting_city','is_valid_restaurants','is_valid_transportation','is_valid_attractions','is_valid_accommodation','is_not_absent'],'hard':['valid_cost','valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']} | |
| for key in constraint_statistic: | |
| # level | |
| for key2 in constraint_statistic[key]: | |
| # day | |
| # print(key2) | |
| # key2 = eval(key2) | |
| if key2 == -1: | |
| print(constraint_statistic[key]) | |
| exit(0) | |
| for key3 in key_dict[constraint]: | |
| data_record[key][key2].append('0/0') | |
| if key3 in constraint_statistic[key][key2]: | |
| constraint_dis_record[constraint]['pass'] += constraint_statistic[key][key2][key3]['true'] | |
| if constraint == 'hard': | |
| if key == 'hard' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type','valid_transportation']: | |
| data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}" | |
| constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3] | |
| elif key == 'medium' and key3 in ['valid_room_rule','valid_cuisine','valid_room_type']: | |
| data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{mapping_constraint_record[key][key2][key3]}" | |
| constraint_dis_record[constraint]['total'] += mapping_constraint_record[key][key2][key3] | |
| else: | |
| data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}" | |
| if key3 in ['valid_cost','valid_visitng_city_number','valid_days']: | |
| constraint_dis_record[constraint]['total'] += count_record[key][key2] | |
| else: | |
| data_record[key][key2][-1] = f"{constraint_statistic[key][key2][key3]['true']}/{count_record[key][key2]}" | |
| constraint_dis_record[constraint]['total'] += count_record[key][key2] | |
| final_all_cnt = 0 | |
| final_commonsense_cnt = 0 | |
| final_hardConstraint_cnt = 0 | |
| final_all_cnt_map = {level:0 for level in ['easy','medium','hard']} | |
| for idx in (range(0,len(query_data_list))): | |
| if plan_constraint_store[idx]['commonsense_constraint']: | |
| final_commonsense_pass = True | |
| final_hardConstraint_pass = True | |
| for item in plan_constraint_store[idx]['commonsense_constraint']: | |
| if plan_constraint_store[idx]['commonsense_constraint'][item][0] is not None and not plan_constraint_store[idx]['commonsense_constraint'][item][0]: | |
| final_commonsense_pass = False | |
| break | |
| if plan_constraint_store[idx]['hard_constraint'] is None: | |
| continue | |
| for item in plan_constraint_store[idx]['hard_constraint']: | |
| if plan_constraint_store[idx]['hard_constraint'][item][0] is not None and plan_constraint_store[idx]['hard_constraint'][item][0] == False: | |
| final_hardConstraint_pass = False | |
| break | |
| if final_commonsense_pass: | |
| final_commonsense_cnt += 1 | |
| if final_hardConstraint_pass: | |
| final_hardConstraint_cnt += 1 | |
| if final_commonsense_pass and final_hardConstraint_pass: | |
| final_all_cnt += 1 | |
| final_all_cnt_map[query_data_list[idx]['level']] += 1 | |
| result = {} | |
| if validation_or_test == 'validation': | |
| result['Delivery Rate'] = delivery_cnt / 180 | |
| result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 1440 | |
| result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 180 | |
| result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 420 | |
| result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 180 | |
| result['Final Pass Rate'] = final_all_cnt / 180 | |
| elif validation_or_test == 'test': | |
| result['Delivery Rate'] = delivery_cnt / 1000 | |
| result['Commonsense Constraint Micro Pass Rate'] = constraint_dis_record['commonsense']['pass'] / 8000 | |
| result['Commonsense Constraint Macro Pass Rate'] = final_commonsense_cnt / 1000 | |
| result['Hard Constraint Micro Pass Rate'] = constraint_dis_record['hard']['pass'] / 2290 | |
| result['Hard Constraint Macro Pass Rate'] = final_hardConstraint_cnt / 1000 | |
| result['Final Pass Rate'] = final_all_cnt / 1000 | |
| return result | |