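The script below evaluates moondream2 on NaturalBench. For each benchmark sample it asks both questions about both images in a single batch_answer call, then reports overall accuracy plus the benchmark's image, question, and group accuracies.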
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..hf import detect_device  # device/dtype helper from the moondream repo

MODEL_ID = "vikhyatk/moondream2"
DEVICE, DTYPE = detect_device()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
moondream = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=DTYPE,
    device_map={"": DEVICE},
)
moondream.eval()

# Yes, the benchmark test set is stored in the 'train' split...
dataset = load_dataset("BaiqiL/NaturalBench", split="train")

acc = []    # per-(image, question) correctness
i_acc = []  # image accuracy: both questions correct for a given image
q_acc = []  # question accuracy: both images correct for a given question
g_acc = []  # group accuracy: all four (image, question) pairs correct

for row in tqdm(dataset):
    # Yes/no questions get an explicit instruction so answers match the labels.
    if row["Question_Type"] == "yes_no":
        suffix = " Answer yes or no."
    else:
        suffix = ""

    # Ask both questions about both images in one batched call.
    answers = moondream.batch_answer(
        images=[row["Image_0"], row["Image_1"], row["Image_0"], row["Image_1"]],
        prompts=[
            row["Question_0"] + suffix,
            row["Question_0"] + suffix,
            row["Question_1"] + suffix,
            row["Question_1"] + suffix,
        ],
        tokenizer=tokenizer,
    )
    expected = [
        row["Image_0_Question_0"],
        row["Image_1_Question_0"],
        row["Image_0_Question_1"],
        row["Image_1_Question_1"],
    ]

    acc.append(answers[0] == expected[0])
    acc.append(answers[1] == expected[1])
    acc.append(answers[2] == expected[2])
    acc.append(answers[3] == expected[3])
    i_acc.append(answers[0] == expected[0] and answers[2] == expected[2])
    i_acc.append(answers[1] == expected[1] and answers[3] == expected[3])
    q_acc.append(answers[0] == expected[0] and answers[1] == expected[1])
    q_acc.append(answers[2] == expected[2] and answers[3] == expected[3])
    g_acc.append(
        answers[0] == expected[0]
        and answers[1] == expected[1]
        and answers[2] == expected[2]
        and answers[3] == expected[3]
    )

print("Overall Accuracy:", sum(acc) / len(acc))
print("Image Accuracy:", sum(i_acc) / len(i_acc))
print("Question Accuracy:", sum(q_acc) / len(q_acc))
print("Group Accuracy:", sum(g_acc) / len(g_acc))