# Linly-ChatFlow / app.py
# Last commit b35cff2 by yuhaofeng-shiba: "Update app.py" (file size 1.78 kB)
import torch
import gradio as gr
import argparse
from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration
# Module-level state shared by the functions below: ``args`` is filled in by
# init_args() and ``lm_generation`` by init_model(); both are None until then.
args = lm_generation = None
def init_args():
    """Populate the module-level ``args`` namespace used by the rest of the app.

    The model/config/tokenizer paths, the sequence length and the int8 switch
    can be overridden from the command line. With no command-line arguments
    the resulting values are exactly the ones that used to be hard-coded
    here. The remaining inference settings (batch size, world size, top-p and
    repetition-penalty parameters) are still assigned directly.

    Side effects:
        Rebinds the global ``args`` to the value returned by
        ``load_hyperparam`` and attaches ``tokenizer`` and ``vocab_size``.
    """
    global args
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # help= strings are needed for ArgumentDefaultsHelpFormatter to show defaults.
    parser.add_argument("--load_model_path", default="./model_file/chatllama_7b.bin",
                        help="Model checkpoint path (passed to load_model).")
    parser.add_argument("--config_path", default="./config/llama_7b.json",
                        help="Model config JSON path.")
    parser.add_argument("--spm_model_path", default="./model_file/tokenizer.model",
                        help="Tokenizer model path (passed to Tokenizer).")
    parser.add_argument("--seq_length", type=int, default=1024,
                        help="Sequence length setting.")
    parser.add_argument("--use_int8", action="store_true",
                        help="Set args.use_int8 to True.")
    args = parser.parse_args()

    # Fixed settings that are not exposed on the command line.
    args.batch_size = 1
    args.world_size = 1
    args.top_p = 0
    args.repetition_penalty_range = 1024
    args.repetition_penalty_slope = 0
    args.repetition_penalty = 1.15

    # load_hyperparam returns the namespace used from here on; the tokenizer
    # must be built afterwards because vocab_size is derived from it.
    args = load_hyperparam(args)
    args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()
def init_model():
    """Build the LLaMA model, load its weights and wrap it for generation.

    Reads the module-level ``args`` prepared by ``init_args()`` and stores
    the resulting ``LmGeneration`` object in the module-level
    ``lm_generation``.
    """
    global lm_generation
    # Switch the process-wide default tensor type to half precision *before*
    # the model is constructed, so its parameters are created as float16.
    # NOTE(review): this default is never restored (a reset to FloatTensor
    # was disabled), so tensors created later in the process also default
    # to half precision. torch.set_default_tensor_type is deprecated in
    # recent PyTorch releases.
    torch.set_default_tensor_type(torch.HalfTensor)
    net = load_model(LLaMa(args), args.load_model_path)
    net.eval()
    # NOTE(review): on the CPU fallback the model stays float16, which some
    # PyTorch CPU kernels may not support — confirm if CPU use matters.
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    net.to(target)
    lm_generation = LmGeneration(net, args.tokenizer)
def chat(prompt, top_k, temperature):
    """Generate one completion for *prompt*; used as the Gradio callback.

    Args:
        prompt: User text from the Gradio textbox.
        top_k: Top-k value from the slider; coerced to ``int``.
        temperature: Temperature value from the slider; coerced to ``float``.

    Returns:
        The first element of the list returned by ``lm_generation.generate``
        for the single-prompt batch.
    """
    import copy

    # Use a per-request shallow copy of the shared settings. Writing
    # top_k/temperature onto the global ``args`` lets concurrently running
    # callbacks overwrite each other's sampling parameters.
    request_args = copy.copy(args)
    request_args.top_k = int(top_k)
    request_args.temperature = float(temperature)
    response = lm_generation.generate(request_args, [prompt])
    return response[0]
if __name__ == '__main__':
    init_args()
    init_model()
    # A prompt textbox plus sliders for top-k (1..60) and temperature
    # (0.1..2.0); Gradio passes their values positionally to chat().
    prompt_input = "text"
    top_k_slider = gr.Slider(minimum=1, maximum=60, value=40, step=1)
    temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.2, step=0.1)
    demo = gr.Interface(
        fn=chat,
        inputs=[prompt_input, top_k_slider, temperature_slider],
        outputs="text",
    )
    demo.launch()