Use torch.inference_mode() and disable gradient checkpointing
#4 · opened by prathamj31

Files changed:
- config.json +4 -1
- modeling_zeranker.py +33 -16
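For context: the change applies the standard PyTorch recipe for an eval-only model. Forward passes run under torch.inference_mode(), which skips autograd bookkeeping entirely, and gradient checkpointing is switched off, since checkpointing only trades extra compute for activation memory during training and buys nothing at inference. A minimal, repo-independent sketch of that recipe (the "gpt2" checkpoint and the prompt are placeholders, not taken from this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

model.eval()                            # inference only: disable dropout and friends
model.gradient_checkpointing_disable()  # checkpointing only helps training-time memory

inputs = tokenizer("Is Paris the capital of France? Yes or No:", return_tensors="pt")
with torch.inference_mode():            # like no_grad, but also skips autograd metadata
    outputs = model(**inputs, use_cache=False)
last_token_logits = outputs.logits[:, -1, :]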
config.json CHANGED

@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
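For reference, the added auto_map entry is what lets the transformers dynamic-module loader resolve the config class from the repo's own modeling_zeranker.py. A rough sketch of the loading path it enables, assuming the checkpoint is fetched with remote code trusted:

from transformers import AutoConfig

# With the auto_map entry above, AutoConfig resolves to ZEConfig defined in
# modeling_zeranker.py (which, per the diff below, is an alias for Qwen3Config).
config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
print(type(config).__name__)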
modeling_zeranker.py CHANGED

@@ -1,9 +1,8 @@
 from sentence_transformers import CrossEncoder as _CE

 import math
+import logging
 from typing import cast, Any
-import types
-

 import torch
 from transformers.configuration_utils import PretrainedConfig
@@ -23,8 +22,10 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false

+logger = logging.getLogger(__name__)
+
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS =
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -74,9 +75,12 @@ def load_model(
     if device is None:
         device = global_device

+    logger.info(f"Loading model from {MODEL_PATH} on device: {device}")
+
     config = AutoConfig.from_pretrained(MODEL_PATH)
     assert isinstance(config, PretrainedConfig)

+    logger.info(f"Loading model with config type: {config.model_type}")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_PATH,
         torch_dtype="auto",
@@ -93,6 +97,7 @@ def load_model(
         | Qwen3ForCausalLM,
     )

+    logger.info("Loading tokenizer")
     tokenizer = cast(
         AutoTokenizer,
         AutoTokenizer.from_pretrained(
@@ -105,6 +110,7 @@ def load_model(
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

+    logger.info("Model and tokenizer loaded successfully")
     return tokenizer, model


@@ -125,13 +131,7 @@ def predict(
             raise ValueError("query_documents or sentences must be provided")
         query_documents = [[sentence[0], sentence[1]] for sentence in sentences]

-
-    self.inner_tokenizer, self.inner_model = load_model(global_device)
-    self.inner_model.gradient_checkpointing_enable()
-    self.inner_model.eval()
-    self.inner_yes_token_id = self.inner_tokenizer.encode(
-        "Yes", add_special_tokens=False
-    )[0]
+    logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")

     model = self.inner_model
     tokenizer = self.inner_tokenizer
@@ -161,9 +161,12 @@ def predict(
         batches[-1].append((query, document))
         max_length = max(max_length, 20 + len(query) + len(document))

+    logger.info(f"Created {len(batches)} batches for inference")
+
     # Inference all of the document batches
     all_logits: list[float] = []
-    for batch in batches:
+    for batch_idx, batch in enumerate(batches):
+        logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
         batch_inputs = format_pointwise_datapoints(
             tokenizer,
             batch,
@@ -172,11 +175,12 @@ def predict(
         batch_inputs = batch_inputs.to(global_device)

         try:
-            outputs = model(**batch_inputs, use_cache=False)
+            with torch.inference_mode():
+                outputs = model(**batch_inputs, use_cache=False)
         except torch.OutOfMemoryError:
-
+            logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
             torch.cuda.empty_cache()
-
+            logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
             outputs = model(**batch_inputs, use_cache=False)

         # Extract the logits
@@ -199,18 +203,31 @@
     # Unsort by indices
     scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]

+    logger.info(f"Prediction complete. Generated {len(scores)} scores")
     return scores


 def to_device(self: _CE, new_device: torch.device) -> None:
     global global_device
+    logger.info(f"Changing device from {global_device} to {new_device}")
     global_device = new_device

+    # Load the model now since __init__ patching doesn't work due to timing
+    # (CrossEncoder instance is created before this module is loaded)
+    if not hasattr(self, "inner_model"):
+        logger.info("Loading model during device setup (eager loading)")
+        self.inner_tokenizer, self.inner_model = load_model(global_device)
+        self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
+        self.inner_yes_token_id = self.inner_tokenizer.encode(
+            "Yes", add_special_tokens=False
+        )[0]
+        logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
+

 _CE.predict = predict
+_CE.to = to_device

 from transformers import Qwen3Config

 ZEConfig = Qwen3Config
-
-_CE.to = to_device
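Rough usage sketch of the patched module after this change; the CrossEncoder loading arguments are an assumption (they depend on how the Hub repo wires up sentence-transformers), not part of the diff. The patched to() now eagerly loads the tokenizer and model, calls eval(), and disables gradient checkpointing, and the patched predict() runs its forward passes under torch.inference_mode():

import torch
from sentence_transformers import CrossEncoder

# Assumed loading call; exact kwargs depend on the repo's sentence-transformers setup.
reranker = CrossEncoder("zeroentropy/zerank-2", trust_remote_code=True)

# Patched _CE.to: sets global_device, then eager-loads the model with eval() and
# gradient_checkpointing_disable() if it has not been loaded yet.
reranker.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Patched _CE.predict: batches the pairs and scores them under torch.inference_mode().
scores = reranker.predict([
    ("What is the capital of France?", "Paris is the capital of France."),
    ("What is the capital of France?", "The Nile is a river in Africa."),
])
print(scores)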