Script Doesn't Run

#1
by randomthings111 - opened
I'm trying to run the reranker through sentence_transformers' `CrossEncoder`, but the script below fails at `predict()` with a `TypeError` raised by the tokenizer (full traceback underneath):

```python
from sentence_transformers import CrossEncoder

MODEL_ID = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"  # or local folder

def format_prompts(query: str, instruction: str, docs: list[str]) -> list[str]:
    inst = f" {instruction}" if instruction else ""
    return [
        "Check whether a given document contains information helpful to answer the query.\n"
        f"<Document> {d}\n"
        f"<Query> {query}{inst} ??"
        for d in docs
    ]

query = "Which is a domestic animal?"
docs = ["Cats are pets.", "The moon is made of cheese.", "Dogs are loyal companions."]

ce = CrossEncoder(MODEL_ID, max_length=8192)

# Ensure original padding behavior
if ce.tokenizer.pad_token is None:
    ce.tokenizer.pad_token = ce.tokenizer.eos_token
ce.tokenizer.padding_side = "left"

prompts = format_prompts(query, "", docs)
scores = ce.predict(prompts)  # one logit per doc (higher = more relevant)

ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
for s, d in ranked:
    print(f"{s:.4f} | {d}")
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipython-input-3313096751.py in <cell line: 0>()
     23 
     24 prompts = format_prompts(query, "", docs)
---> 25 scores = ce.predict(prompts)  # one logit per doc (higher = more relevant)
     26 
     27 ranked = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)

6 frames
/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    118     def decorate_context(*args, **kwargs):
    119         with ctx_factory():
--> 120             return func(*args, **kwargs)
    121 
    122     return decorate_context

/usr/local/lib/python3.12/dist-packages/sentence_transformers/cross_encoder/util.py in wrapper(self, *args, **kwargs)
     66                 )
     67 
---> 68         return func(self, *args, **kwargs)
     69 
     70     return wrapper

/usr/local/lib/python3.12/dist-packages/sentence_transformers/cross_encoder/CrossEncoder.py in predict(self, sentences, batch_size, show_progress_bar, activation_fn, apply_softmax, convert_to_numpy, convert_to_tensor)
    456         for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
    457             batch = sentences[start_index : start_index + batch_size]
--> 458             features = self.tokenizer(
    459                 batch,
    460                 padding=True,

/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py in __call__(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)
   2936             if not self._in_target_context_manager:
   2937                 self._switch_to_input_mode()
-> 2938             encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
   2939         if text_target is not None:
   2940             self._switch_to_target_mode()

/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py in _call_one(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)
   3024                 )
   3025             batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
-> 3026             return self.batch_encode_plus(
   3027                 batch_text_or_text_pairs=batch_text_or_text_pairs,
   3028                 add_special_tokens=add_special_tokens,

/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_base.py in batch_encode_plus(self, batch_text_or_text_pairs, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)
   3225         )
   3226 
-> 3227         return self._batch_encode_plus(
   3228             batch_text_or_text_pairs=batch_text_or_text_pairs,
   3229             add_special_tokens=add_special_tokens,

/usr/local/lib/python3.12/dist-packages/transformers/tokenization_utils_fast.py in _batch_encode_plus(self, batch_text_or_text_pairs, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, padding_side, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens)
    551             self._tokenizer.encode_special_tokens = split_special_tokens
    552 
--> 553         encodings = self._tokenizer.encode_batch(
    554             batch_text_or_text_pairs,
    555             add_special_tokens=add_special_tokens,

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]
```
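
For anyone looking at this: the failure happens inside the tokenizer call that `CrossEncoder.predict` makes, and `predict` is documented to take (query, document) pairs, so my suspicion is that a plain list of pre-formatted prompt strings gets treated as one single oversized "pair", which the fast tokenizer then rejects. A workaround I'm considering is to skip `CrossEncoder` and score the formatted prompts directly with `transformers`. Minimal sketch, untested, and it assumes the `-seq-cls` checkpoint loads as a single-label `AutoModelForSequenceClassification` model:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_ID = "sigridjineth/ctxl-rerank-v2-1b-seq-cls"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model.eval()

# Same padding setup as in the script above; decoder-style seq-cls heads also
# need config.pad_token_id so they can locate the last real token in a padded batch.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model.config.pad_token_id = tokenizer.pad_token_id

query = "Which is a domestic animal?"
docs = ["Cats are pets.", "The moon is made of cheese.", "Dogs are loyal companions."]

# Same prompt template as format_prompts() above, with an empty instruction.
prompts = [
    "Check whether a given document contains information helpful to answer the query.\n"
    f"<Document> {d}\n"
    f"<Query> {query} ??"
    for d in docs
]

features = tokenizer(prompts, padding=True, truncation=True,
                     max_length=8192, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits.squeeze(-1)  # one relevance logit per prompt

for s, d in sorted(zip(scores.tolist(), docs), reverse=True):
    print(f"{s:.4f} | {d}")
```

If the interpretation above is right, the other option would be to hand `predict` a list of `(query, document)` tuples instead of single strings, but then the model-specific prompt template would not be applied.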
