Script Doesn't Run
#1 opened by randomthings111

Running the snippet below (plain CrossEncoder usage with pre-formatted prompts) fails with a TypeError inside the tokenizer:
```python
from sentence_transformers import CrossEncoder

MODEL_ID = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"  # or local folder


def format_prompts(query: str, instruction: str, docs: list[str]) -> list[str]:
    inst = f" {instruction}" if instruction else ""
    return [
        "Check whether a given document contains information helpful to answer the query.\n"
        f"<Document> {d}\n"
        f"<Query> {query}{inst} ??"
        for d in docs
    ]


query = "Which is a domestic animal?"
docs = ["Cats are pets.", "The moon is made of cheese.", "Dogs are loyal companions."]

ce = CrossEncoder(MODEL_ID, max_length=8192)

# Ensure original padding behavior
if ce.tokenizer.pad_token is None:
    ce.tokenizer.pad_token = ce.tokenizer.eos_token
ce.tokenizer.padding_side = "left"

prompts = format_prompts(query, "", docs)
scores = ce.predict(prompts)  # one logit per doc (higher = more relevant)

ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
for s, d in ranked:
    print(f"{s:.4f} | {d}")
```
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-3313096751.py in <cell line: 0>()
23
24 prompts = format_prompts(query, "", docs)
---> 25 scores = ce.predict(prompts) # one logit per doc (higher = more relevant)
26
27 ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
6 frames
/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
118 def decorate_context(*args, **kwargs):
119 with ctx_factory():
--> 120 return func(*args, **kwargs)
121
122 return decorate_context
/usr/local/lib/python3.12/dist-packages/sentence_transformers/cross_encoder/util.py in wrapper(self, *args, **kwargs)
66 )
67
---> 68 return func(self, *args, **kwargs)
69
70 return wrapper
/usr/local/lib/python3.12/dist-packages/sentence_transformers/cross_encoder/CrossEncoder.py in predict(self, sentences, batch_size, show_progress_bar, activation_fn, apply_softmax, convert_to_numpy, convert_to_tensor)
456 for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
457 batch = sentences[start_index : start_index + batch_size]
--> 458 features = self.tokenizer(
459 batch,
460 padding=True,
/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py in __call__(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
2936 if not self._in_target_context_manager:
2937 self._switch_to_input_mode()
-> 2938 encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
2939 if text_target is not None:
2940 self._switch_to_target_mode()
/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py in _call_one(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)
3024 )
3025 batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
-> 3026 return self.batch_encode_plus(
3027 batch_text_or_text_pairs=batch_text_or_text_pairs,
3028 add_special_tokens=add_special_tokens,
/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py in batch_encode_plus(self, batch_text_or_text_pairs, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)
3225 )
3226
-> 3227 return self._batch_encode_plus(
3228 batch_text_or_text_pairs=batch_text_or_text_pairs,
3229 add_special_tokens=add_special_tokens,
/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_fast.py in _batch_encode_plus(self, batch_text_or_text_pairs, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens)
551 self._tokenizer.encode_special_tokens = split_special_tokens
552
--> 553 encodings = self._tokenizer.encode_batch(
554 batch_text_or_text_pairs,
555 add_special_tokens=add_special_tokens,
TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]
```
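Since the prompts are already fully formatted single sequences (the query and document are baked into one string), one workaround that should sidestep `predict`'s pair handling is to call the underlying tokenizer and sequence-classification model directly. A minimal sketch, assuming the `ce`, `prompts`, and `docs` from the script above and a single-logit classification head:

```python
import torch

# Tokenize the pre-formatted prompts as single sequences (no text pairs),
# reusing the CrossEncoder's own tokenizer and model.
features = ce.tokenizer(
    prompts,
    padding=True,
    truncation=True,
    max_length=8192,
    return_tensors="pt",
).to(ce.model.device)

with torch.no_grad():
    # Shape [batch, 1] for a num_labels=1 head -> squeeze to one score per doc.
    scores = ce.model(**features).logits.squeeze(-1).tolist()

ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
for s, d in ranked:
    print(f"{s:.4f} | {d}")
```

Alternatively, passing `(query, document)` tuples, e.g. `ce.predict([(query, d) for d in docs])`, matches the input type `predict` expects, but that would skip the prompt template built by `format_prompts`.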