from __future__ import annotations

import gc
import logging
import os
import platform

import torch
import colorama

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel


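# The tokenizer and model are cached in module-level globals so that at most
# one copy of ChatGLM is loaded per process; deinitialize() both defines and
# resets these globals.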
class ChatGLM_Client(BaseLLMModel):
    def __init__(self, model_name, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        import torch
        from transformers import AutoModel, AutoTokenizer
        global CHATGLM_TOKENIZER, CHATGLM_MODEL
        # Drop any previously loaded model before (re)loading this one; this
        # also defines the globals on first use.
        self.deinitialize()
        if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None:
            system_name = platform.system()
            # Prefer a checkpoint downloaded into the local models/ directory;
            # otherwise fall back to the THUDM repository on the Hugging Face Hub.
            model_path = None
            if os.path.exists("models"):
                model_dirs = os.listdir("models")
                if model_name in model_dirs:
                    model_path = f"models/{model_name}"
            if model_path is not None:
                model_source = model_path
            else:
                model_source = f"THUDM/{model_name}"
            CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(
                model_source, trust_remote_code=True
            )
            quantified = "int4" in model_name
            model = AutoModel.from_pretrained(
                model_source, trust_remote_code=True
            )
            # Pick a device: CUDA if available, MPS only for local,
            # non-quantized checkpoints on macOS, CPU otherwise.
            if torch.cuda.is_available():
                logging.info("CUDA is available, using CUDA")
                model = model.half().cuda()
            elif system_name == "Darwin" and model_path is not None and not quantified:
                logging.info("Running on macOS, using MPS")
                model = model.half().to("mps")
            else:
                logging.info("GPU is not available, using CPU")
                model = model.float()
            model = model.eval()
            CHATGLM_MODEL = model

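    # ChatGLM3-style input: self.history is a list of message dicts (each with
    # at least a "content" key); the latest message is popped off (removing it
    # from self.history) and returned separately as the query.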
    def _get_glm3_style_input(self):
        history = self.history
        query = history.pop()["content"]
        return history, query

    def _get_glm2_style_input(self):
        # ChatGLM/ChatGLM2-style input: a flat list of message contents, where
        # the last entry is the query and the rest are regrouped into
        # [user, assistant] pairs.
        history = [x["content"] for x in self.history]
        query = history.pop()
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        assert (
            len(history) % 2 == 0
        ), f"History should have an even length. Current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def _get_glm_style_input(self):
        if "glm2" in self.model_name:
            return self._get_glm2_style_input()
        else:
            return self._get_glm3_style_input()

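    # Non-streaming generation: returns the full response text and its length
    # in characters.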
    def get_answer_at_once(self):
        history, query = self._get_glm_style_input()
        response, _ = CHATGLM_MODEL.chat(
            CHATGLM_TOKENIZER, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
        # Streaming generation: stream_chat yields the accumulated response
        # after each step, so every yielded string is the full reply so far.
        history, query = self._get_glm_style_input()
        for response, history in CHATGLM_MODEL.stream_chat(
            CHATGLM_TOKENIZER,
            query,
            history,
            max_length=self.token_upper_limit,
            top_p=self.top_p,
            temperature=self.temperature,
        ):
            yield response

    def deinitialize(self):
        # Release the model and tokenizer and free GPU memory.
        global CHATGLM_MODEL, CHATGLM_TOKENIZER
        CHATGLM_MODEL = None
        CHATGLM_TOKENIZER = None
        gc.collect()
        torch.cuda.empty_cache()
        logging.info("ChatGLM model deinitialized")
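

# Usage sketch (illustrative only, not part of the module). It assumes
# BaseLLMModel exposes self.history as a list of {"role", "content"} dicts,
# which is how the methods above consume it, and uses the standard
# ChatGLM3 checkpoint name on the Hub:
#
#     client = ChatGLM_Client("chatglm3-6b")
#     client.history.append({"role": "user", "content": "Hello"})
#     for partial in client.get_answer_stream_iter():
#         print(partial)
#     client.deinitialize()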