Use torch.inference_mode() and disable gradient checkpointing

#4
by prathamj31 - opened
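
The gist of the change: the reranker's forward passes now run under torch.inference_mode() instead of plain autograd mode, and gradient checkpointing is disabled rather than enabled, since checkpointing only trades compute for memory on the backward pass and just adds recompute overhead at inference time. A minimal sketch of that pattern (generic code, not the exact repo implementation; score_batch is a hypothetical helper and the output shape assumes a causal-LM-style model):

    import torch
    from transformers import PreTrainedModel

    def score_batch(model: PreTrainedModel, batch_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
        # Checkpointing only saves memory during backward; for pure inference
        # it just forces activation recomputation, so keep it off.
        model.gradient_checkpointing_disable()
        model.eval()
        # inference_mode() skips autograd bookkeeping entirely: no graph is
        # built and no activations are retained for a backward pass.
        with torch.inference_mode():
            outputs = model(**batch_inputs, use_cache=False)
        return outputs.logits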
Files changed (2)
  1. config.json +4 -1
  2. modeling_zeranker.py +33 -16
config.json CHANGED
@@ -64,5 +64,8 @@
   "transformers_version": "4.57.1",
   "use_cache": true,
   "use_sliding_window": false,
-  "vocab_size": 151936
+  "vocab_size": 151936,
+  "auto_map": {
+    "AutoConfig": "modeling_zeranker.ZEConfig"
+  }
 }
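
For context, the auto_map entry added above is what lets transformers resolve the custom config class from code shipped in the model repo. A rough usage sketch, assuming the repo ships modeling_zeranker.py at its root (trust_remote_code=True is required before any auto_map code is executed):

    from transformers import AutoConfig

    # auto_map points AutoConfig at modeling_zeranker.ZEConfig in the repo,
    # so the custom module is imported (and its CrossEncoder patches applied)
    # when the config is loaded.
    config = AutoConfig.from_pretrained("zeroentropy/zerank-2", trust_remote_code=True)
    print(type(config))  # ZEConfig is an alias of Qwen3Config in modeling_zeranker.py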
modeling_zeranker.py CHANGED
@@ -1,9 +1,8 @@
 from sentence_transformers import CrossEncoder as _CE
 
 import math
+import logging
 from typing import cast, Any
-import types
-
 
 import torch
 from transformers.configuration_utils import PretrainedConfig
@@ -23,8 +22,10 @@ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 # pyright: reportUnknownMemberType=false
 # pyright: reportUnknownVariableType=false
 
+logger = logging.getLogger(__name__)
+
 MODEL_PATH = "zeroentropy/zerank-2"
-PER_DEVICE_BATCH_SIZE_TOKENS = 15_000
+PER_DEVICE_BATCH_SIZE_TOKENS = 10_000
 global_device = (
     torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 )
@@ -74,9 +75,12 @@ def load_model(
     if device is None:
         device = global_device
 
+    logger.info(f"Loading model from {MODEL_PATH} on device: {device}")
+
     config = AutoConfig.from_pretrained(MODEL_PATH)
     assert isinstance(config, PretrainedConfig)
 
+    logger.info(f"Loading model with config type: {config.model_type}")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_PATH,
         torch_dtype="auto",
@@ -93,6 +97,7 @@ def load_model(
         | Qwen3ForCausalLM,
     )
 
+    logger.info("Loading tokenizer")
     tokenizer = cast(
         AutoTokenizer,
         AutoTokenizer.from_pretrained(
@@ -105,6 +110,7 @@ def load_model(
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    logger.info("Model and tokenizer loaded successfully")
     return tokenizer, model
 
 
@@ -125,13 +131,7 @@ def predict(
             raise ValueError("query_documents or sentences must be provided")
         query_documents = [[sentence[0], sentence[1]] for sentence in sentences]
 
-    if not hasattr(self, "inner_model"):
-        self.inner_tokenizer, self.inner_model = load_model(global_device)
-        self.inner_model.gradient_checkpointing_enable()
-        self.inner_model.eval()
-        self.inner_yes_token_id = self.inner_tokenizer.encode(
-            "Yes", add_special_tokens=False
-        )[0]
+    logger.info(f"Starting prediction for {len(query_documents)} query-document pairs")
 
     model = self.inner_model
     tokenizer = self.inner_tokenizer
@@ -161,9 +161,12 @@ def predict(
         batches[-1].append((query, document))
         max_length = max(max_length, 20 + len(query) + len(document))
 
+    logger.info(f"Created {len(batches)} batches for inference")
+
     # Inference all of the document batches
     all_logits: list[float] = []
-    for batch in batches:
+    for batch_idx, batch in enumerate(batches):
+        logger.debug(f"Processing batch {batch_idx + 1}/{len(batches)} with {len(batch)} pairs")
         batch_inputs = format_pointwise_datapoints(
             tokenizer,
             batch,
@@ -172,11 +175,12 @@ def predict(
         batch_inputs = batch_inputs.to(global_device)
 
         try:
-            outputs = model(**batch_inputs, use_cache=False)
+            with torch.inference_mode():
+                outputs = model(**batch_inputs, use_cache=False)
         except torch.OutOfMemoryError:
-            print(f"GPU OOM! {torch.cuda.memory_reserved()}")
+            logger.warning(f"GPU OOM! Memory reserved: {torch.cuda.memory_reserved()}")
             torch.cuda.empty_cache()
-            print(f"GPU After OOM Cache Clear: {torch.cuda.memory_reserved()}")
+            logger.info(f"GPU cache cleared. Memory reserved: {torch.cuda.memory_reserved()}")
             outputs = model(**batch_inputs, use_cache=False)
 
         # Extract the logits
@@ -199,18 +203,31 @@ def predict(
     # Unsort by indices
     scores = [score for _, score in sorted(zip(permutation, scores, strict=True))]
 
+    logger.info(f"Prediction complete. Generated {len(scores)} scores")
     return scores
 
 
 def to_device(self: _CE, new_device: torch.device) -> None:
     global global_device
+    logger.info(f"Changing device from {global_device} to {new_device}")
     global_device = new_device
 
+    # Load the model now since __init__ patching doesn't work due to timing
+    # (CrossEncoder instance is created before this module is loaded)
+    if not hasattr(self, "inner_model"):
+        logger.info("Loading model during device setup (eager loading)")
+        self.inner_tokenizer, self.inner_model = load_model(global_device)
+        self.inner_model.eval()
+        self.inner_model.gradient_checkpointing_disable()
+        self.inner_yes_token_id = self.inner_tokenizer.encode(
+            "Yes", add_special_tokens=False
+        )[0]
+        logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
+
 
 _CE.predict = predict
+_CE.to = to_device
 
 from transformers import Qwen3Config
 
 ZEConfig = Qwen3Config
-
-_CE.to = to_device
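
As a quick end-to-end check of the patched paths, a hedged usage sketch (the model id and trust_remote_code follow the diff above; the query/document strings are made up, and the exact loading path assumes the Hub repo wires modeling_zeranker.py in via the auto_map entry added in config.json):

    from sentence_transformers import CrossEncoder

    # CrossEncoder.predict and CrossEncoder.to are the methods monkey-patched by
    # modeling_zeranker.py, so loading with trust_remote_code pulls the patches in.
    model = CrossEncoder("zeroentropy/zerank-2", trust_remote_code=True)
    scores = model.predict([
        ("what is gradient checkpointing?",
         "Gradient checkpointing trades extra compute for lower training memory."),
        ("what is gradient checkpointing?",
         "Tomatoes are botanically a fruit."),
    ])
    print(scores)  # one relevance score per (query, document) pair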