gemma2_9b_7gb / fn.py
aka7774's picture
Upload 6 files
5653716 verified
raw
history blame
4.95 kB
import os
import torch
import json
import gc
import time
from unsloth import FastLanguageModel
from transformers import TextIteratorStreamer
from threading import Thread
# Turn off the Hugging Face `tokenizers` library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Handles for the currently loaded model and tokenizer; load_model() assigns them.
tokenizer = None
model = None

# Factory settings. `cfg` below is the live, mutable copy that the functions
# in this module read and update.
default_cfg = dict(
    # Repo id / path handed to FastLanguageModel.from_pretrained by load_model().
    model_name="unsloth/gemma-2-9b-it-bnb-4bit",
    # Quantization mode understood by load_model(), e.g. '4bit' or '8bit'.
    dtype=None,
    # System prompt for the chat-history path, or a str.format template with an
    # {input} placeholder for the plain-text path.
    instruction=None,
    # Optional str.format template with {instruction} and {input} placeholders.
    inst_template=None,
    # Optional chat template that is assigned onto the tokenizer.
    chat_template=None,
    # Token limit used to truncate the prompt when it is tokenized.
    max_length=2400,
    # Passed as max_seq_length to FastLanguageModel.from_pretrained.
    max_seq_length=2048,
    # Generation settings, forwarded to model.generate() only when truthy.
    max_new_tokens=512,
    temperature=0.9,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.2,
)
cfg = dict(default_cfg)
def load_model(model_name, dtype):
    """Load *model_name* into the module-level ``model`` / ``tokenizer``.

    Does nothing when ``cfg`` already records the same model name and dtype.
    Otherwise releases the current model, loads the new one with unsloth's
    ``FastLanguageModel.from_pretrained`` in bfloat16 and switches it to
    inference mode.

    Args:
        model_name: repo id or path passed to ``from_pretrained``.
        dtype: quantization mode string; ``'8bit'`` sets ``load_in_8bit``,
            ``'4bit'`` sets ``load_in_4bit``, any other value sets neither.
    """
    global tokenizer, model
    if cfg['model_name'] == model_name and cfg['dtype'] == dtype:
        return
    # Drop the only references to the old model so its memory can be freed
    # before the new weights are allocated.
    model = None
    tokenizer = None
    # Forget the old identity up front: if loading below raises, the next call
    # must retry instead of hitting the early return with model still None.
    cfg['model_name'] = None
    cfg['dtype'] = None
    gc.collect()
    torch.cuda.empty_cache()
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name,
        max_seq_length=cfg['max_seq_length'],
        dtype=torch.bfloat16,
        load_in_8bit=(dtype == '8bit'),
        load_in_4bit=(dtype == '4bit'),
    )
    FastLanguageModel.for_inference(model)
    # Record the identity only after a successful load.
    cfg['model_name'] = model_name
    cfg['dtype'] = dtype
def clear_config():
    """Reset ``cfg`` to ``default_cfg``, keeping the loaded-model identity.

    ``model_name`` and ``dtype`` are carried over from the current ``cfg``
    because load_model() compares against them to decide whether a reload is
    needed. No model is unloaded here, so resetting those two keys would make
    ``cfg`` describe a model that is not actually loaded.
    """
    global cfg
    cfg = {
        **default_cfg,
        'model_name': cfg['model_name'],
        'dtype': cfg['dtype'],
    }
def set_config(model_name, dtype, instruction, inst_template, chat_template,
               max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Ensure the requested model is loaded, then store prompt/generation settings.

    Numeric settings are coerced (int for token counts and top_k, float for
    the rest) before being written into ``cfg``.

    Returns:
        The status string ``'done.'``.
    """
    load_model(model_name, dtype)
    new_values = dict(
        instruction=instruction,
        inst_template=inst_template,
        chat_template=chat_template,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
    )
    cfg.update(new_values)
    return 'done.'
def set_config_args(args):
    """Apply a settings mapping in one call.

    *args* must contain ``'model_name'`` and ``'dtype'`` (used to make sure
    that model is loaded). Every key/value in *args* is then copied into
    ``cfg`` unchanged.

    Returns:
        The status string ``'done.'``.
    """
    requested_model = args['model_name']
    requested_dtype = args['dtype']
    load_model(requested_model, requested_dtype)
    for key, value in args.items():
        cfg[key] = value
    return 'done.'
def chatinterface_to_messages(message, history):
    """Build a role/content message list for ``apply_chat_template``.

    The list starts with ``cfg['instruction']`` as a system message when it
    is set, then the history turns, then *message* as the final user turn.
    Empty turns are skipped.

    Args:
        message: the new user message; skipped when empty.
        history: prior turns, either as ``[user, assistant]`` pairs or as
            ``{'role': ..., 'content': ...}`` dicts (the two history formats
            a Gradio ChatInterface can supply). ``None`` means no history.

    Returns:
        A list of ``{'role': str, 'content': ...}`` dicts.
    """
    messages = []
    if cfg['instruction']:
        messages.append({'role': 'system', 'content': cfg['instruction']})
    for entry in history or []:
        if isinstance(entry, dict):
            # Already a role/content message; unpacking it as a pair would
            # yield its key names (or raise), so copy the two fields instead.
            if entry.get('content'):
                messages.append({'role': entry['role'], 'content': entry['content']})
            continue
        user, assistant = entry
        if user:
            messages.append({'role': 'user', 'content': user})
        if assistant:
            messages.append({'role': 'assistant', 'content': assistant})
    if message:
        messages.append({'role': 'user', 'content': message})
    return messages
def apply_template(messages):
    """Render *messages* into a prompt string.

    If ``cfg['chat_template']`` is set, it is first assigned onto the
    tokenizer (the assignment persists for later calls).

    Args:
        messages: either a plain string or a list/tuple of role/content dicts.
            - str: formatted with ``cfg['inst_template']`` (placeholders
              ``{instruction}`` and ``{input}``) when set, else with
              ``cfg['instruction']`` (placeholder ``{input}``); returned
              unchanged when neither is set.
            - list/tuple: rendered with the tokenizer's chat template, with a
              trailing generation prompt, as text (not token ids).

    Returns:
        The prompt string.

    Raises:
        TypeError: if *messages* is neither a str nor a list/tuple.
    """
    if cfg['chat_template']:
        tokenizer.chat_template = cfg['chat_template']
    if isinstance(messages, str):
        if cfg['inst_template']:
            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
        if not cfg['instruction']:
            # No template of any kind: the raw text is the prompt.
            return messages
        return cfg['instruction'].format(input=messages)
    if isinstance(messages, (list, tuple)):
        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
    raise TypeError(f"apply_template expects str or list of messages, got {type(messages).__name__}")
def chat(message, history=None, instruction=None, args=None):
    """Generate a reply to *message*, yielding text as it streams.

    When *instruction* is given it is stored in ``cfg['instruction']`` (this
    persists for later calls) and *message* is rendered via the plain-text
    template path. Otherwise *history* plus *message* are converted to chat
    messages and rendered with the tokenizer's chat template.
    ``model.generate`` runs on a background thread and text is read from a
    ``TextIteratorStreamer``; iteration fails if no new text arrives within
    10 seconds.

    Args:
        message: the new user message (or raw input when *instruction* is set).
        history: prior turns in the format chatinterface_to_messages accepts;
            ``None`` means no history.
        instruction: optional instruction overriding ``cfg['instruction']``.
        args: caller options; if it contains the key ``'fastapi'`` each yield
            is only the newly generated text, otherwise each yield is the
            full reply so far. ``None`` means no options.

    Yields:
        str chunks as described for *args*.
    """
    history = [] if history is None else history
    args = {} if args is None else args
    if instruction:
        cfg['instruction'] = instruction
        prompt = apply_template(message)
    else:
        messages = chatinterface_to_messages(message, history)
        prompt = apply_template(messages)
    # A rendered chat template already contains the BOS token; letting the
    # tokenizer add special tokens again would prepend a second BOS.
    bos = tokenizer.bos_token
    add_special_tokens = not (bos and prompt.startswith(bos))
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        max_length=cfg['max_length'],
        truncation=True,
        add_special_tokens=add_special_tokens,
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
    )
    generate_kwargs = dict(
        inputs,
        do_sample=True,
        streamer=streamer,
        num_beams=1,
    )
    for k in ('max_new_tokens', 'temperature', 'top_p', 'top_k', 'repetition_penalty'):
        # Falsy settings (None/0) are left out so generate() uses its own defaults.
        if cfg[k]:
            generate_kwargs[k] = cfg[k]
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        if 'fastapi' in args:
            # FastAPI callers want only the newly generated delta.
            yield new_text
        else:
            # Gradio wants the full text so far on every update.
            yield model_output
    # The streamer is exhausted once generation ends; reap the worker thread.
    t.join()
def infer(message, history=None, instruction=None, args=None):
    """Run chat() to completion and return the whole reply as one string.

    Args:
        message, history, instruction, args: forwarded to chat(); ``None``
            history/args are treated as empty.

    Returns:
        The complete generated reply.
    """
    history = [] if history is None else history
    args = {} if args is None else args
    # chat() yields deltas when 'fastapi' is in args, but cumulative snapshots
    # otherwise. Deltas must be concatenated; a snapshot replaces the previous
    # one (concatenating snapshots would repeat every prefix of the reply).
    yields_deltas = 'fastapi' in args
    content = ''
    for chunk in chat(message, history, instruction, args):
        content = content + chunk if yields_deltas else chunk
    return content
def numel(message, history=None, instruction=None, args=None):
    """Return how many tokens the prompt for *message* encodes to.

    The prompt is built exactly like a generation prompt: when *instruction*
    is given it is stored in ``cfg['instruction']`` (persisting for later
    calls) and the plain-text template path is used; otherwise *history* and
    *message* go through the chat template. No truncation is applied, so the
    result can exceed ``cfg['max_length']``.

    Args:
        message: the new user message (or raw input when *instruction* is set).
        history: prior turns; ``None`` means no history.
        instruction: optional instruction overriding ``cfg['instruction']``.
        args: unused; kept for signature parity with chat().

    Returns:
        The token count as an int.
    """
    history = [] if history is None else history
    if instruction:
        cfg['instruction'] = instruction
        prompt = apply_template(message)
    else:
        messages = chatinterface_to_messages(message, history)
        prompt = apply_template(messages)
    # Don't count a second BOS when the rendered template already starts with one.
    bos = tokenizer.bos_token
    add_special_tokens = not (bos and prompt.startswith(bos))
    # Plain Python lists suffice for counting; no tensor or device copy needed.
    model_inputs = tokenizer([prompt], add_special_tokens=add_special_tokens)
    return len(model_inputs['input_ids'][0])
# Eagerly load the default model, 4-bit quantized, when this module is imported.
load_model(default_cfg['model_name'], '4bit')