```python
import gradio as gr
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
import torch, transformers
from threading import Thread
import time

# Load the 1-bit HQQ-quantized model with its LoRA adapter
model_id = 'mobiuslabsgmbh/Llama-2-7b-chat-hf_1bitgs8_hqq'
model = HQQModelForCausalLM.from_quantized(model_id, adapter='adapter_v0.1.lora', device='cuda')
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Setup inference mode
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.config.use_cache = True
model.eval()

# Optional: torch compile for faster inference
model = torch.compile(model)

def chat_processor(chat, max_new_tokens=100, do_sample=True, device='cuda'):
    # Run generation in a background thread and stream tokens through a TextIteratorStreamer
    tokenizer.use_default_system_prompt = False
    streamer = transformers.TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_params = dict(
        tokenizer("<s> [INST] " + chat + " [/INST] ", return_tensors="pt").to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        top_p=0.90 if do_sample else None,
        top_k=50 if do_sample else None,
        temperature=0.6 if do_sample else None,
        num_beams=1,
        repetition_penalty=1.2,
    )

    t = Thread(target=model.generate, kwargs=generate_params)
    t.start()

    return t, streamer

def chat(message, history):
    # Gradio streaming callback: yield the partial response as new text arrives
    t, stream = chat_processor(chat=message)
    response = ""
    for character in stream:
        if character is not None:
            response += character
        yield response
        time.sleep(0.1)
    t.join()
    torch.cuda.empty_cache()

gr.ChatInterface(chat).launch()
```
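
For a quick sanity check without launching the Gradio app, the same streamer returned by `chat_processor` can be consumed directly in a loop. Below is a minimal sketch of that console path, assuming the model, tokenizer, and `chat_processor` defined above are already loaded; the prompt string is only an example.

```python
# Minimal console test of the streaming path (example prompt, not from the model card)
prompt = "What is 1-bit quantization?"

t, streamer = chat_processor(prompt, max_new_tokens=100)

print("User:", prompt)
print("Assistant: ", end="")
output = ""
for text in streamer:                # TextIteratorStreamer yields decoded text chunks as they are generated
    output += text
    print(text, end="", flush=True)
print()

t.join()                             # wait for the background generation thread to finish
torch.cuda.empty_cache()             # release cached GPU memory between runs
```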