Commit
·
3aefef2
1 Parent(s):
f899c80
Remove unused __init__ method and related code
Browse files
The __init__ patching never worked due to timing (a CrossEncoder instance
is created before module is loaded). Model loading happens in to_device()
instead. Also removed unused 'types' import.
- modeling_zeranker.py +0 -23
modeling_zeranker.py
CHANGED
|
@@ -3,8 +3,6 @@ from sentence_transformers import CrossEncoder as _CE
|
|
| 3 |
import math
|
| 4 |
import logging
|
| 5 |
from typing import cast, Any
|
| 6 |
-
import types
|
| 7 |
-
|
| 8 |
|
| 9 |
import torch
|
| 10 |
from transformers.configuration_utils import PretrainedConfig
|
|
@@ -116,26 +114,6 @@ def load_model(
|
|
| 116 |
return tokenizer, model
|
| 117 |
|
| 118 |
|
| 119 |
-
# Store the original __init__ method
|
| 120 |
-
_original_init = _CE.__init__
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
| 124 |
-
logger.info("Initializing CrossEncoder with eager model loading")
|
| 125 |
-
# Call the original CrossEncoder __init__ first
|
| 126 |
-
_original_init(self, *args, **kwargs)
|
| 127 |
-
|
| 128 |
-
# Load the model immediately on instantiation
|
| 129 |
-
logger.info("Loading model on instantiation (no lazy loading)")
|
| 130 |
-
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 131 |
-
self.inner_model.eval()
|
| 132 |
-
self.inner_model.gradient_checkpointing_disable()
|
| 133 |
-
self.inner_yes_token_id = self.inner_tokenizer.encode(
|
| 134 |
-
"Yes", add_special_tokens=False
|
| 135 |
-
)[0]
|
| 136 |
-
logger.info(f"CrossEncoder initialization complete. Yes token ID: {self.inner_yes_token_id}")
|
| 137 |
-
|
| 138 |
-
|
| 139 |
def predict(
|
| 140 |
self,
|
| 141 |
query_documents: list[tuple[str, str]] | None = None,
|
|
@@ -247,7 +225,6 @@ def to_device(self: _CE, new_device: torch.device) -> None:
|
|
| 247 |
logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
|
| 248 |
|
| 249 |
|
| 250 |
-
_CE.__init__ = __init__
|
| 251 |
_CE.predict = predict
|
| 252 |
_CE.to = to_device
|
| 253 |
|
|
|
|
| 3 |
import math
|
| 4 |
import logging
|
| 5 |
from typing import cast, Any
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
from transformers.configuration_utils import PretrainedConfig
|
|
|
|
| 114 |
return tokenizer, model
|
| 115 |
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
def predict(
|
| 118 |
self,
|
| 119 |
query_documents: list[tuple[str, str]] | None = None,
|
|
|
|
| 225 |
logger.info(f"Model loaded successfully. Yes token ID: {self.inner_yes_token_id}")
|
| 226 |
|
| 227 |
|
|
|
|
| 228 |
_CE.predict = predict
|
| 229 |
_CE.to = to_device
|
| 230 |
|