import os
import torch
import json
import gc
import time
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread

os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Globals populated by load_model()
tokenizer = None
model = None

# Default model and generation settings
default_cfg = {
    'model_name': "unsloth/gemma-2-9b-it-bnb-4bit",
    'dtype': None,
    'instruction': None,
    'inst_template': None,
    'chat_template': None,
    'max_length': 2400,
    'max_seq_length': 2048,
    'max_new_tokens': 512,
    'temperature': 0.9,
    'top_p': 0.95,
    'top_k': 40,
    'repetition_penalty': 1.2,
}
cfg = default_cfg.copy()
def load_model(model_name, dtype):
    global tokenizer, model, cfg
    # Skip reloading when the requested model/quantization is already loaded
    if cfg['model_name'] == model_name and cfg['dtype'] == dtype:
        return
    # Free the previous model before loading a new one
    del model
    del tokenizer
    model = None
    tokenizer = None
    gc.collect()
    torch.cuda.empty_cache()
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name,
        max_seq_length=cfg['max_seq_length'],
        dtype=torch.bfloat16,
        load_in_8bit=(dtype == '8bit'),
        load_in_4bit=(dtype == '4bit'),
    )
    FastLanguageModel.for_inference(model)
    cfg['model_name'] = model_name
    cfg['dtype'] = dtype
def clear_config():
    global cfg
    cfg = default_cfg.copy()
def set_config(model_name, dtype, instruction, inst_template, chat_template, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    global cfg
    load_model(model_name, dtype)
    cfg.update({
        'instruction': instruction,
        'inst_template': inst_template,
        'chat_template': chat_template,
        'max_new_tokens': int(max_new_tokens),
        'temperature': float(temperature),
        'top_p': float(top_p),
        'top_k': int(top_k),
        'repetition_penalty': float(repetition_penalty),
    })
    return 'done.'
def set_config_args(args):
    global cfg
    load_model(args['model_name'], args['dtype'])
    cfg.update(args)
    return 'done.'
def chatinterface_to_messages(message, history):
    # Convert Gradio ChatInterface-style history ([[user, assistant], ...])
    # into an OpenAI-style messages list.
    global cfg
    messages = []
    if cfg['instruction']:
        messages.append({'role': 'system', 'content': cfg['instruction']})
    for user, assistant in history:
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    if message:
        messages.append({'role': 'user', 'content': message})
    return messages
def apply_template(messages):
    global tokenizer, cfg
    if cfg['chat_template']:
        tokenizer.chat_template = cfg['chat_template']
    if isinstance(messages, str):
        # Plain-string input: format it with the instruction template
        if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
        return cfg['instruction'].format(input=messages)
    if isinstance(messages, list):
        # Messages list: let the tokenizer's chat template build the prompt
        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
def chat(message, history = [], instruction = None, args = {}):
    global tokenizer, model, cfg
    if instruction:
        cfg['instruction'] = instruction
        prompt = apply_template(message)
    else:
        messages = chatinterface_to_messages(message, history)
        prompt = apply_template(messages)
    inputs = tokenizer(prompt, return_tensors="pt",
                       padding=True, max_length=cfg['max_length'], truncation=True).to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
    )
    generate_kwargs = dict(
        inputs,
        do_sample=True,
        streamer=streamer,
        num_beams=1,
    )
    # Only pass sampling parameters that are actually set
    for k in [
        'max_new_tokens',
        'temperature',
        'top_p',
        'top_k',
        'repetition_penalty'
    ]:
        if cfg[k]:
            generate_kwargs[k] = cfg[k]
    # Run generation in a background thread and stream tokens as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        if 'fastapi' in args:
            # FastAPI expects only the newly generated text (the delta)
            yield new_text
        else:
            # Gradio expects the full text so far on every yield
            yield model_output
def infer(message, history = [], instruction = None, args = {}):
    # Non-streaming wrapper: drain the chat() generator and return the final text
    content = ''
    for s in chat(message, history, instruction, args):
        # chat() yields deltas in fastapi mode and the full text so far otherwise
        content = content + s if 'fastapi' in args else s
    return content
def numel(message, history = [], instruction = None, args = {}):
    # Return the number of prompt tokens for the given message/history
    global tokenizer, model, cfg
    if instruction:
        cfg['instruction'] = instruction
        prompt = apply_template(message)
    else:
        messages = chatinterface_to_messages(message, history)
        prompt = apply_template(messages)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    return torch.numel(model_inputs['input_ids'])
# Load the default model in 4-bit at import time
load_model(cfg['model_name'], '4bit')
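
# Usage sketch (illustrative, not part of the original file). The Gradio wiring
# below assumes the `gradio` package is available; `respond` and the launch
# settings are hypothetical. For a FastAPI server, a streaming route could
# instead pass args={'fastapi': True} so chat() yields only the deltas, e.g.
# StreamingResponse(chat(msg, hist, args={'fastapi': True})).
if __name__ == "__main__":
    import gradio as gr

    def respond(message, history):
        # ChatInterface passes history as [[user, assistant], ...] pairs,
        # which chatinterface_to_messages() already expects.
        yield from chat(message, history)

    gr.ChatInterface(respond).queue().launch()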