from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Json = Dict[str, Any]
Messages = List[Dict[str, str]]


def _is_messages(x: Any) -> bool:
    return (
        isinstance(x, list)
        and len(x) > 0
        and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
    )


class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.

    Supports both text and chat formats:

    Text format:
        {"inputs": "Hello, how are you?"}

    Chat format (recommended):
        {"inputs": [{"role": "user", "content": "Hello!"}]}
    or
        {"inputs": {"messages": [{"role": "user", "content": "Hello!"}]}}

    Parameters:
    - max_new_tokens (default: 256): maximum number of new tokens to generate
    - temperature (default: 0.7): sampling temperature
    - top_p (default: 0.95): nucleus-sampling threshold
    - top_k (default: 0, i.e. disabled): top-k sampling cutoff
    - do_sample (default: True when temperature > 0): sampling vs. greedy/beam decoding
    - num_beams (default: 1): beam-search width
    - repetition_penalty (default: 1.0): penalize repeated tokens
    - return_full_text (default: False): if True, return the full decoded sequence
      (prompt plus completion); if False, only the newly generated tokens

    A runnable smoke-test sketch at the bottom of this module shows a
    complete example call.
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir

        # Prefer GPU with bfloat16; fall back to float32 on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=True,
        )

        # Some models ship without a pad token; reuse EOS so generate() can pad.
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
        )

        # device_map="auto" already places weights on GPU; only move on CPU.
        if self.device != "cuda":
            self.model.to(self.device)

        self.model.eval()

    @torch.inference_mode()
    def __call__(self, data: Json) -> Union[Json, List[Json]]:
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Generation parameters, coerced to the expected types with defaults.
        max_new_tokens = int(params.get("max_new_tokens", 256))
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.95))
        top_k = int(params.get("top_k", 0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        return_full_text = bool(params.get("return_full_text", False))

        # Sampling is on by default whenever temperature is positive.
        do_sample = bool(params.get("do_sample", temperature > 0))
        num_beams = int(params.get("num_beams", 1))

        def _one(item: Any) -> Json:
            # Normalize the three accepted input shapes:
            #   {"messages": [...]}              -> unwrap to the message list
            #   [{"role": ..., "content": ...}]  -> chat messages
            #   anything else                    -> plain text prompt
            if isinstance(item, dict) and "messages" in item:
                item = item["messages"]

            if _is_messages(item):
                try:
                    # Render the conversation with the model's chat template,
                    # appending the generation prompt for the assistant turn.
                    prompt = self.tokenizer.apply_chat_template(
                        item,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    enc = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = enc["input_ids"]
                except Exception:
                    # No usable chat template: fall back to the most recent
                    # user message as a plain prompt.
                    last_user_msg = next(
                        (m["content"] for m in reversed(item) if m.get("role") == "user"),
                        "",
                    )
                    enc = self.tokenizer(last_user_msg, return_tensors="pt")
                    input_ids = enc["input_ids"]
            else:
                if not isinstance(item, str):
                    item = str(item)
                enc = self.tokenizer(item, return_tensors="pt")
                input_ids = enc["input_ids"]

            input_ids = input_ids.to(self.model.device)
            input_len = input_ids.shape[-1]

            gen_ids = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                top_k=top_k if do_sample and top_k > 0 else None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            # Decode either the full sequence or only the completion.
            if return_full_text:
                text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True)
            else:
                new_tokens = gen_ids[0, input_len:]
                text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return {"generated_text": text}

        # A list that is not itself a chat conversation is treated as a batch
        # of independent prompts; everything else is a single request.
        if isinstance(inputs, list) and not _is_messages(inputs):
            return [_one(x) for x in inputs]
        else:
            return _one(inputs)
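

# A minimal local smoke test (a sketch, not part of the Inference Endpoints
# request contract): the platform only constructs EndpointHandler with the
# model directory and calls it with the parsed JSON payload. "path/to/model"
# below is a placeholder; point it at a local checkpoint or a downloaded Hub
# snapshot to exercise the handler outside an endpoint.
if __name__ == "__main__":
    handler = EndpointHandler("path/to/model")  # placeholder model directory

    # Chat-format request mirroring the docstring example; temperature 0.0
    # leaves do_sample False, so decoding is greedy and deterministic.
    out = handler(
        {
            "inputs": [{"role": "user", "content": "Hello!"}],
            "parameters": {"max_new_tokens": 32, "temperature": 0.0},
        }
    )
    print(out["generated_text"])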