Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import logging | |
import os | |
import platform | |
import gc | |
import torch | |
import colorama | |
from ..index_func import * | |
from ..presets import * | |
from ..utils import * | |
from .base_model import BaseLLMModel | |
class ChatGLM_Client(BaseLLMModel): | |
def __init__(self, model_name, user_name="") -> None: | |
super().__init__(model_name=model_name, user=user_name) | |
import torch | |
from transformers import AutoModel, AutoTokenizer | |
global CHATGLM_TOKENIZER, CHATGLM_MODEL | |
self.deinitialize() | |
if CHATGLM_TOKENIZER is None or CHATGLM_MODEL is None: | |
system_name = platform.system() | |
model_path = None | |
if os.path.exists("models"): | |
model_dirs = os.listdir("models") | |
if model_name in model_dirs: | |
model_path = f"models/{model_name}" | |
if model_path is not None: | |
model_source = model_path | |
else: | |
model_source = f"THUDM/{model_name}" | |
CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained( | |
model_source, trust_remote_code=True | |
) | |
quantified = False | |
if "int4" in model_name: | |
quantified = True | |
model = AutoModel.from_pretrained( | |
model_source, trust_remote_code=True | |
) | |
if torch.cuda.is_available(): | |
# run on CUDA | |
logging.info("CUDA is available, using CUDA") | |
model = model.half().cuda() | |
# mps加速还存在一些问题,暂时不使用 | |
elif system_name == "Darwin" and model_path is not None and not quantified: | |
logging.info("Running on macOS, using MPS") | |
# running on macOS and model already downloaded | |
model = model.half().to("mps") | |
else: | |
logging.info("GPU is not available, using CPU") | |
model = model.float() | |
model = model.eval() | |
CHATGLM_MODEL = model | |
def _get_glm3_style_input(self): | |
history = self.history | |
query = history.pop()["content"] | |
return history, query | |
def _get_glm2_style_input(self): | |
history = [x["content"] for x in self.history] | |
query = history.pop() | |
logging.debug(colorama.Fore.YELLOW + | |
f"{history}" + colorama.Fore.RESET) | |
assert ( | |
len(history) % 2 == 0 | |
), f"History should be even length. current history is: {history}" | |
history = [[history[i], history[i + 1]] | |
for i in range(0, len(history), 2)] | |
return history, query | |
def _get_glm_style_input(self): | |
if "glm2" in self.model_name: | |
return self._get_glm2_style_input() | |
else: | |
return self._get_glm3_style_input() | |
def get_answer_at_once(self): | |
history, query = self._get_glm_style_input() | |
response, _ = CHATGLM_MODEL.chat( | |
CHATGLM_TOKENIZER, query, history=history) | |
return response, len(response) | |
def get_answer_stream_iter(self): | |
history, query = self._get_glm_style_input() | |
for response, history in CHATGLM_MODEL.stream_chat( | |
CHATGLM_TOKENIZER, | |
query, | |
history, | |
max_length=self.token_upper_limit, | |
top_p=self.top_p, | |
temperature=self.temperature, | |
): | |
yield response | |
def deinitialize(self): | |
# 释放显存 | |
global CHATGLM_MODEL, CHATGLM_TOKENIZER | |
CHATGLM_MODEL = None | |
CHATGLM_TOKENIZER = None | |
gc.collect() | |
torch.cuda.empty_cache() | |
logging.info("ChatGLM model deinitialized") | |