Spaces:
Runtime error
Runtime error
import gc | |
import yaml | |
import json | |
import torch | |
from transformers import GenerationConfig | |
from models import alpaca, stablelm, koalpaca, flan_alpaca, mpt | |
from models import camel, t5_vicuna, vicuna, starchat, redpajama, bloom | |
from models import baize, guanaco, falcon, kullm, replit, airoboros | |
from models import samantha_vicuna | |
from utils import get_chat_interface, get_chat_manager | |
model_infos = json.load(open("model_cards.json")) | |
def get_model_type(model_info): | |
base_url = model_info["hub(base)"] | |
ft_ckpt_url = model_info["hub(ckpt)"] | |
model_type_tmp = "alpaca" | |
if "llms/wizardlm" in base_url.lower(): | |
model_type_tmp = "wizardlm" | |
elif "chronos" in base_url.lower(): | |
model_type_tmp = "chronos" | |
elif "lazarus" in base_url.lower(): | |
model_type_tmp = "lazarus" | |
elif "samantha" in base_url.lower(): | |
model_type_tmp = "samantha-vicuna" | |
elif "airoboros" in base_url.lower(): | |
model_type_tmp = "airoboros" | |
elif "replit" in base_url.lower(): | |
model_type_tmp = "replit-instruct" | |
elif "kullm" in base_url.lower(): | |
model_type_tmp = "kullm-polyglot" | |
elif "nous-hermes" in base_url.lower(): | |
model_type_tmp = "nous-hermes" | |
elif "guanaco" in base_url.lower(): | |
model_type_tmp = "guanaco" | |
elif "wizardlm-uncensored-falcon" in base_url.lower(): | |
model_type_tmp = "wizard-falcon" | |
elif "falcon" in base_url.lower(): | |
model_type_tmp = "falcon" | |
elif "baize" in base_url.lower(): | |
model_type_tmp = "baize" | |
elif "stable-vicuna" in base_url.lower(): | |
model_type_tmp = "stable-vicuna" | |
elif "vicuna" in base_url.lower(): | |
model_type_tmp = "vicuna" | |
elif "mpt" in base_url.lower(): | |
model_type_tmp = "mpt" | |
elif "redpajama-incite-7b-instruct" in base_url.lower(): | |
model_type_tmp = "redpajama-instruct" | |
elif "redpajama" in base_url.lower(): | |
model_type_tmp = "redpajama" | |
elif "starchat" in base_url.lower(): | |
model_type_tmp = "starchat" | |
elif "camel" in base_url.lower(): | |
model_type_tmp = "camel" | |
elif "flan-alpaca" in base_url.lower(): | |
model_type_tmp = "flan-alpaca" | |
elif "openassistant/stablelm" in base_url.lower(): | |
model_type_tmp = "os-stablelm" | |
elif "stablelm" in base_url.lower(): | |
model_type_tmp = "stablelm" | |
elif "fastchat-t5" in base_url.lower(): | |
model_type_tmp = "t5-vicuna" | |
elif "koalpaca-polyglot" in base_url.lower(): | |
model_type_tmp = "koalpaca-polyglot" | |
elif "alpacagpt4" in ft_ckpt_url.lower(): | |
model_type_tmp = "alpaca-gpt4" | |
elif "alpaca" in ft_ckpt_url.lower(): | |
model_type_tmp = "alpaca" | |
elif "llama-deus" in ft_ckpt_url.lower(): | |
model_type_tmp = "llama-deus" | |
elif "vicuna-lora-evolinstruct" in ft_ckpt_url.lower(): | |
model_type_tmp = "evolinstruct-vicuna" | |
elif "alpacoom" in ft_ckpt_url.lower(): | |
model_type_tmp = "alpacoom" | |
elif "guanaco" in ft_ckpt_url.lower(): | |
model_type_tmp = "guanaco" | |
else: | |
print("unsupported model type") | |
return model_type_tmp | |
def initialize_globals(): | |
global models, tokenizers | |
models = [] | |
model_names = [ | |
"baize-7b", | |
# "evolinstruct-vicuna-13b", | |
"guanaco-7b", | |
# "nous-hermes-13b" | |
] | |
for model_name in model_names: | |
model_info = model_infos[model_name] | |
model_thumbnail_tiny = model_info["thumb-tiny"] | |
model_type = get_model_type(model_info) | |
print(model_type) | |
load_model = get_load_model(model_type) | |
model, tokenizer = load_model( | |
base=model_info["hub(base)"], | |
finetuned=model_info["hub(ckpt)"], | |
mode_cpu=False, | |
mode_mps=False, | |
mode_full_gpu=False, | |
mode_8bit=True, | |
mode_4bit=True, | |
force_download_ckpt=False | |
) | |
gen_config, gen_config_raw = get_generation_config( | |
model_info["default_gen_config"] | |
) | |
models.append( | |
{ | |
"model_name": model_name, | |
"model_thumb_tiny": model_thumbnail_tiny, | |
"model_type": model_type, | |
"model": model, | |
"tokenizer": tokenizer, | |
"gen_config": gen_config, | |
"gen_config_raw": gen_config_raw, | |
"chat_interface": get_chat_interface(model_type), | |
"chat_manager": get_chat_manager(model_type), | |
} | |
) | |
def get_load_model(model_type): | |
if model_type == "alpaca" or \ | |
model_type == "alpaca-gpt4" or \ | |
model_type == "llama-deus" or \ | |
model_type == "nous-hermes" or \ | |
model_type == "lazarus" or \ | |
model_type == "chronos" or \ | |
model_type == "wizardlm": | |
return alpaca.load_model | |
elif model_type == "stablelm" or model_type == "os-stablelm": | |
return stablelm.load_model | |
elif model_type == "koalpaca-polyglot": | |
return koalpaca.load_model | |
elif model_type == "kullm-polyglot": | |
return kullm.load_model | |
elif model_type == "flan-alpaca": | |
return flan_alpaca.load_model | |
elif model_type == "camel": | |
return camel.load_model | |
elif model_type == "t5-vicuna": | |
return t5_vicuna.load_model | |
elif model_type == "stable-vicuna": | |
return vicuna.load_model | |
elif model_type == "starchat": | |
return starchat.load_model | |
elif model_type == "mpt": | |
return mpt.load_model | |
elif model_type == "redpajama" or \ | |
model_type == "redpajama-instruct": | |
return redpajama.load_model | |
elif model_type == "vicuna": | |
return vicuna.load_model | |
elif model_type == "evolinstruct-vicuna": | |
return alpaca.load_model | |
elif model_type == "alpacoom": | |
return bloom.load_model | |
elif model_type == "baize": | |
return baize.load_model | |
elif model_type == "guanaco": | |
return guanaco.load_model | |
elif model_type == "falcon" or model_type == "wizard-falcon": | |
return falcon.load_model | |
elif model_type == "replit-instruct": | |
return replit.load_model | |
elif model_type == "airoboros": | |
return airoboros.load_model | |
elif model_type == "samantha-vicuna": | |
return samantha_vicuna.load_model | |
else: | |
return None | |
def get_generation_config(path): | |
with open(path, 'rb') as f: | |
generation_config = yaml.safe_load(f.read()) | |
generation_config = generation_config["generation_config"] | |
return GenerationConfig(**generation_config), generation_config | |
def get_constraints_config(path): | |
with open(path, 'rb') as f: | |
constraints_config = yaml.safe_load(f.read()) | |
return ConstraintsConfig(**constraints_config), constraints_config["constraints"] | |