Spaces:
Paused
Paused
| import torch | |
| from peft import PeftModel | |
| from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer | |
| def generate_prompt(instruction, input=None): | |
| if input: | |
| return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. | |
| ### Instruction: | |
| {instruction} | |
| ### Input: | |
| {input} | |
| ### Response:""" | |
| else: | |
| return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. | |
| ### Instruction: | |
| {instruction} | |
| ### Response:""" | |
| def load_tokenizer_and_model(base_model,adapter_model,load_8bit=False): | |
| if torch.cuda.is_available(): | |
| device = "cuda" | |
| else: | |
| device = "cpu" | |
| try: | |
| if torch.backends.mps.is_available(): | |
| device = "mps" | |
| except: # noqa: E722 | |
| pass | |
| tokenizer = LlamaTokenizer.from_pretrained(base_model) | |
| if device == "cuda": | |
| model = LlamaForCausalLM.from_pretrained( | |
| base_model, | |
| load_in_8bit=load_8bit, | |
| torch_dtype=torch.float16, | |
| device_map="auto", | |
| ) | |
| model = PeftModel.from_pretrained( | |
| model, | |
| adapter_model, | |
| torch_dtype=torch.float16, | |
| ) | |
| elif device == "mps": | |
| model = LlamaForCausalLM.from_pretrained( | |
| base_model, | |
| device_map={"": device}, | |
| torch_dtype=torch.float16, | |
| ) | |
| model = PeftModel.from_pretrained( | |
| model, | |
| adapter_model, | |
| device_map={"": device}, | |
| torch_dtype=torch.float16, | |
| ) | |
| else: | |
| model = LlamaForCausalLM.from_pretrained( | |
| base_model, device_map={"": device}, low_cpu_mem_usage=True | |
| ) | |
| model = PeftModel.from_pretrained( | |
| model, | |
| adapter_model, | |
| device_map={"": device}, | |
| ) | |
| if not load_8bit: | |
| model.half() # seems to fix bugs for some users. | |
| model.eval() | |
| if torch.__version__ >= "2": | |
| model = torch.compile(model) | |
| return tokenizer,model,device | |