# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script provides a simple evaluation pipeline for multiple-choice reasoning datasets
(e.g., BoolQ, HellaSwag, ARC, OpenBookQA, Winogrande) with different composition strategies.
Usage examples:
python arrow_phi3_mini.py --strategy base --ds_name arc-challenge
python arrow_phi3_mini.py --strategy arrow --ds_name boolq
python arrow_phi3_mini.py --strategy gks --ds_name hswag
Key features:
- Supports three strategies:
• "base"  → Evaluate the quantized base model directly
• "arrow" → Use Arrow modular routing with task-specific adapters
• "gks"   → Use Arrow + GenKnowSub (subtracting general-domain knowledge)
- Loads evaluation datasets from the Hugging Face Hub
- Implements a batched evaluation loop that computes per-option likelihoods and selects
the answer with the lowest average loss (see the sketch after this list)
- Reports simple accuracy
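The per-option scoring rule, as a sketch (`score_option` is an illustrative stand-in for the
tokenization, context masking, and loss averaging implemented in `evaluate_on_multi_choice_batched`
and `compute_loglike_loss` below):
    losses = [score_option(model, tokenizer, text) for text in option_texts]
    prediction = int(torch.argmin(torch.tensor(losses)))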
Implementation details:
- The base model is quantized to 4-bit using `BitsAndBytesConfig` (nf4, bf16 compute).
- For Arrow and GKS, task-specific adapters are loaded from the Hugging Face Hub:
TahaBa/phi3-mini-clustered-flan/ts_expert_i
- Task-specific adapters were trained on 10 clusters of FLAN tasks.
- The clusters were created using Model-Based Clustering (MBC):
1. Train a LoRA adapter for each individual task.
2. Apply k-means clustering to group tasks based on these adapters.
3. Train a LoRA adapter for each resulting cluster.
For more details, see the Arrow paper: https://huggingface.co/papers/2405.11157
- For GKS, general adapters are loaded from:
TahaBa/phi3-mini-general-adapters/...
- These adapters were trained on English, French, and German Wikipedia data
using a causal language modeling objective with (507-token context → 5-token completion) pairs.
- This setup encodes general knowledge into the LoRA space, which can then be
subtracted from task-specific adapters during inference to isolate and purify them.
For more details, see the GenKnowSub paper: https://huggingface.co/papers/2505.10939
- `evaluate_on_multi_choice_batched` handles tokenization, masking context tokens,
and computing per-choice log-likelihoods for fair comparison.
- Accuracy is printed at the end for the selected dataset.
This script is mainly meant for demonstration purposes and lightweight evaluation,
not full-scale benchmarking (batch size / max length can be tuned).
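A minimal sketch of how the "gks" model is constructed (this mirrors the `__main__` block at the
bottom of this file; for plain "arrow", `use_gks` is omitted and `general_adapter_paths` stays empty):
    arrow_config = ArrowConfig(top_k=3, router_temperature=1.0, use_gks=True)
    model = create_arrow_model(
        base_model=base_model,
        task_specific_adapter_paths=[f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)],
        general_adapter_paths=general_adapter_paths,  # the three Wikipedia LoRAs listed above
        arrow_config=arrow_config,
    )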
=======================================================================================
Results (evaluated with microsoft/Phi-3-mini-4k-instruct, 4-bit quantization):
| Dataset | Base Acc. | Arrow Acc. | Arrow+GKS Acc. |
|--------------|-----------|------------|----------------|
| ARC-Challenge| 0.4515 | 0.5418 | 0.5585 |
| ARC-Easy | 0.6894 | 0.8404 | 0.8473 |
| Winogrande | 0.5769 | 0.6550 | 0.6724 |
| BoolQ | 0.8146 | 0.8030 | 0.8247 |
| OpenBookQA | 0.43 | 0.448 | 0.472 |
| HellaSwag | 0.7318 | 0.7150 | 0.7376 |
Observations:
- Arrow generally improves over the base model by routing tokens to the most relevant task adapters.
- Applying GKS (general knowledge subtraction) consistently gives further gains compared to Arrow and Base.
These numbers are not meant as leaderboard results, but as a sanity check
to verify that the implementation works as expected and to illustrate
the benefits of Arrow and GenKnowSub.
"""
import argparse
import random
import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import ArrowConfig, create_arrow_model
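# Base checkpoint used by every strategy; the tokenizer's model_max_length is set to MODEL_MAX_LEN below.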
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
MODEL_MAX_LEN = 2048
def parse_args():
parser = argparse.ArgumentParser(description="Evaluation script with strategy selection")
parser.add_argument(
"--strategy",
type=str,
choices=["base", "arrow", "gks"],
default="base",
help="Evaluation strategy to use: base, arrow, or gks",
)
parser.add_argument(
"--ds_name",
type=str,
choices=["boolq", "hswag", "arc-easy", "arc-challenge", "oqa", "wg"],
default="arc-challenge",
help="Dataset to use: boolq, hswag, arc-easy, arc-challenge, oqa, wg",
)
return parser.parse_args()
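# Load the validation split of the requested benchmark from the Hugging Face Hub.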
def read_test_dataset(ds_name):
if ds_name == "boolq":
ds = load_dataset("google/boolq", split="validation", trust_remote_code=True)
elif ds_name == "hswag":
ds = load_dataset("Rowan/hellaswag", split="validation", trust_remote_code=True)
elif ds_name == "arc-challenge":
ds = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="validation", trust_remote_code=True)
elif ds_name == "arc-easy":
ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="validation", trust_remote_code=True)
elif ds_name == "oqa":
ds = load_dataset("allenai/openbookqa", split="validation", trust_remote_code=True)
elif ds_name == "wg":
ds = load_dataset("allenai/winogrande", "winogrande_xl", split="validation", trust_remote_code=True)
else:
raise ValueError(f"Dataset {ds_name} is not supported yet.")
return ds
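# Build the raw question/context string for a single example; the relevant fields differ per dataset.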
def extract_input_content(ds_name, row):
if ds_name == "boolq":
return f"[passage]{row['passage']}[question]{row['question']}"
if ds_name == "hswag":
return row["ctx"]
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
return row["question"]
if ds_name == "oqa":
return row["question_stem"]
if ds_name == "wg":
return row["sentence"]
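# Render each answer option as a full Phi-3 chat prompt (context + candidate completion), so every
# option can be scored with the same language-modeling loss.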
def create_multi_choice_options(row, ds_name):
options_texts = []
content = extract_input_content(ds_name, row)
if ds_name == "boolq":
choices = ["true", "false"]
if ds_name == "hswag":
choices = row["endings"]
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
choices = row["choices"]["text"]
if ds_name == "wg":
choices = [row["option1"], row["option2"]]
if ds_name == "oqa":
choices = row["choices"]["text"]
for choice in choices:
options_texts.append(f"<|user|>\n{content}<|end|>\n<|assistant|>{choice}<|end|>\n")
return options_texts
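# Map the gold answer of each dataset to the index of the corresponding option returned by
# create_multi_choice_options.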
def extract_multi_choice_target_index(row, ds_name):
if ds_name == "boolq":
return 0 if row["answer"] is True else 1
if ds_name == "hswag":
return int(row["label"])
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
return row["choices"]["label"].index(row["answerKey"])
if ds_name == "wg":
return int(row["answer"]) - 1
if ds_name == "oqa":
return row["choices"]["label"].index(row["answerKey"])
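# Seed Python, NumPy, and PyTorch (including all CUDA devices) for reproducibility.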
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
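# Average cross-entropy over the completion tokens of each sequence: positions labelled -100
# (context and padding) are ignored by CrossEntropyLoss, and the remaining per-token losses are
# averaged so that options of different lengths are comparable.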
def compute_loglike_loss(logits, labels, reduction="none"):
bs = logits.size(0)
vocab_size = logits.size(-1)
labels = labels.squeeze(-1)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# reshape back
if reduction == "none":
loss = loss.view((bs, -1))
non_zero_loss = (loss != 0).sum(dim=-1)
non_zero_loss[non_zero_loss == 0] = 1
loss = loss.sum(dim=-1) / non_zero_loss
return loss.float() # Convert to float32 before returning
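# Batched multiple-choice evaluation: all options of all samples in a batch are flattened into a
# single forward pass, context tokens of each option are masked out of the labels, and the option
# with the lowest average completion loss is taken as the prediction.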
def evaluate_on_multi_choice_batched(
eval_dataset, model, tokenizer, ds_name, labels, predictions, args, batch_size=32, max_length=512, device="cuda"
):
# Put the model in eval mode (disables dropout); gradients are disabled below via torch.no_grad()
model.eval()
for start in tqdm(
range(0, len(eval_dataset), batch_size), total=(len(eval_dataset) + batch_size - 1) // batch_size
):
rows = [eval_dataset[i] for i in range(start, min(start + batch_size, len(eval_dataset)))]
# Build the flattened option texts for this batch
all_texts = []
options_per_sample = [] # number of options for each sample
ctx_lens_per_option = [] # context length replicated per option
for row in rows:
# options: ["<|user|>...<|assistant|>choiceA<|end|>", ...]
options = create_multi_choice_options(row, ds_name)
options_per_sample.append(len(options))
# compute context length once per sample (the -1 keeps the mask aligned with the label shift in compute_loglike_loss)
content = extract_input_content(ds_name, row)
context_prompt = f"<|user|>\n{content}<|end|>\n<|assistant|>"
ctx_len = len(tokenizer.encode(context_prompt)) - 1
all_texts.extend(options)
ctx_lens_per_option.extend([ctx_len] * len(options))
# collect gold label
labels.append(extract_multi_choice_target_index(row, ds_name))
# Tokenize all options in one go
tokenized = tokenizer(
all_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
tokenized = {k: v.to(device) for k, v in tokenized.items()}
# Create masked labels: ignore context and padding
masked_labels = tokenized["input_ids"].clone()
for i, ctx_len in enumerate(ctx_lens_per_option):
masked_labels[i, :ctx_len] = -100
masked_labels[tokenized["attention_mask"] == 0] = -100
with torch.no_grad():
logits = model(input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"]).logits
# per-sequence losses
losses = compute_loglike_loss(logits, masked_labels, reduction="none").detach().cpu()
# Reduce per sample (argmin across its options)
idx = 0
for n_opt in options_per_sample:
pred = torch.argmin(losses[idx : idx + n_opt]).item()
predictions.append(pred)
idx += n_opt
print(
f"Accuracy for dataset {args.ds_name} and strategy {args.strategy} is: {accuracy_score(labels, predictions)}"
)
if __name__ == "__main__":
args = parse_args()
print(f"Selected strategy: {args.strategy}")
print(f"Dataset name: {args.ds_name}")
# Loading the tokeniser
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
use_fast=True,
padding_side="right",
model_max_length=MODEL_MAX_LEN,
)
# Quantisation config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
)
# Loading the model
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
)
# Loading the test dataset
test_dataset = read_test_dataset(args.ds_name)
print(f"{args.ds_name} is loaded with size: {len(test_dataset)}.")
labels, predictions = [], []
if args.strategy == "base":
# Batch-wise inference
with torch.no_grad():
evaluate_on_multi_choice_batched(
test_dataset,
base_model,
tokenizer,
args.ds_name,
labels,
predictions,
args,
batch_size=64, # tune this
max_length=512, # tune if options are long
device="cuda",
)
else:
general_adapter_paths = []
if args.strategy == "gks":
arrow_config = ArrowConfig(
top_k=3,
router_temperature=1.0,
use_gks=True,
)
# General adapter paths from the hub
general_adapter_paths = [
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langen/checkpoint-17",
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langfr/checkpoint-35",
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langger/checkpoint-17",
]
else:
arrow_config = ArrowConfig(
top_k=3,
router_temperature=1.0,
)
# Task-specific adapter paths from the hub
task_specific_adapter_paths = [f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)]
# Creating the Arrow model
model = create_arrow_model(
base_model=base_model,
task_specific_adapter_paths=task_specific_adapter_paths,
general_adapter_paths=general_adapter_paths,
arrow_config=arrow_config,
)
# Batch-wise inference
with torch.no_grad():
evaluate_on_multi_choice_batched(
test_dataset,
model,
tokenizer,
args.ds_name,
labels,
predictions,
args,
batch_size=32, # tune this
max_length=512, # tune if options are long
device="cuda",
)