from typing import Any, Dict, List

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
# bfloat16 requires compute capability 8.x (Ampere) or newer; fall back to float16 on older GPUs.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and model from the repository the endpoint was created from.
        tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            return_dict=True,
            device_map="auto",  # requires `accelerate`; places the model on the available device(s)
            torch_dtype=dtype,
            trust_remote_code=True,
        )
        # Configure generation once at startup and reuse it for every request.
        generation_config = model.generation_config
        generation_config.max_new_tokens = 2000
        generation_config.do_sample = False  # greedy decoding; recent transformers releases reject temperature=0 when sampling
        generation_config.num_return_sequences = 1
        # Models without a dedicated pad token (e.g. Falcon) can reuse the EOS token for padding.
        generation_config.pad_token_id = tokenizer.eos_token_id
        generation_config.eos_token_id = tokenizer.eos_token_id
        self.generation_config = generation_config
        self.pipeline = transformers.pipeline(
            "text-generation", model=model, tokenizer=tokenizer
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Inference Endpoints delivers the request payload as JSON with the prompt under "inputs".
        prompt = data.pop("inputs", data)
        # The pipeline returns a list of dicts, each with a "generated_text" key.
        result = self.pipeline(prompt, generation_config=self.generation_config)
        return result
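

# A minimal local smoke test, as a sketch: Inference Endpoints instantiates
# EndpointHandler once and then calls it per request, so the same call shape
# can be exercised locally before deploying. The model path below is a
# hypothetical example, not part of the handler contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="tiiuae/falcon-7b-instruct")  # assumption: any causal LM repo
    output = handler({"inputs": "Explain what a custom endpoint handler does."})
    print(output)  # e.g. [{"generated_text": "..."}]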