|
import gradio as gr |
|
import argparse |
|
import os, gc, torch |
|
from datetime import datetime |
|
from huggingface_hub import hf_hub_download |
|
import torch |
|
|
|
|
|
|
|
ctx_limit = 4096 |
|
desc = f'''链接:<a href='https://colab.research.google.com/drive/1J1gLMMMA8GbD9JuQt6OKmwCTl9mWU0bb?usp=sharing'>太慢了?用Colab自己部署吧</a> <br /> <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://zhuanlan.zhihu.com/p/618011122" target="_blank" style="margin:0 0.5em">知乎教程</a> |
|
''' |
|
|
|
parser = argparse.ArgumentParser(prog = 'ChatGal RWKV') |
|
parser.add_argument('--share',action='store_true') |
|
parser.add_argument('--ckpt',type=str,default="rwkv-loramerge-0426-v2-4096-epoch11.pth") |
|
parser.add_argument('--model_path',type=str,default=None,help="local model path") |
|
parser.add_argument('--lora', type=str, default=None, help='lora checkpoint path') |
|
parser.add_argument('--lora_alpha', type=float, default=0, help='lora alpha') |
|
parser.add_argument('--lora_layer_filter',type=str,default=None,help='layer filter. Default merge all layer. Example: "0.2*25-31"') |
|
args = parser.parse_args() |
|
os.environ["RWKV_JIT_ON"] = '1' |
|
|
|
|
|
from rwkv_lora import RWKV |
|
lora_kwargs = { |
|
"lora":args.lora, |
|
"lora_alpha":args.lora_alpha, |
|
"lora_layer_filter":args.lora_layer_filter |
|
} |
|
if args.model_path: |
|
model_path = args.model_path |
|
else: |
|
model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename=args.ckpt) |
|
|
|
if torch.cuda.is_available() and torch.cuda.device_count()>0: |
|
os.environ["RWKV_JIT_ON"] = '0' |
|
os.environ["RWKV_CUDA_ON"] = '0' |
|
model = RWKV(model=model_path, strategy='cuda bf16',**lora_kwargs) |
|
else: |
|
model = RWKV(model=model_path, strategy='cpu bf16',**lora_kwargs) |
|
from utils import PIPELINE, PIPELINE_ARGS |
|
pipeline = PIPELINE(model, "20B_tokenizer.json") |
|
|
|
def infer( |
|
ctx, |
|
token_count=10, |
|
temperature=0.7, |
|
top_p=1.0, |
|
top_k=50, |
|
typical_p=1.0, |
|
presencePenalty = 0.05, |
|
countPenalty = 0.05, |
|
): |
|
args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p), top_k=int(top_k),typical_p=float(typical_p), |
|
alpha_frequency = countPenalty, |
|
alpha_presence = presencePenalty, |
|
token_ban = [0], |
|
token_stop = []) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
all_tokens = [] |
|
out_last = 0 |
|
out_str = '' |
|
occurrence = {} |
|
state = None |
|
for i in range(int(token_count)): |
|
out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state) |
|
for n in args.token_ban: |
|
out[n] = -float('inf') |
|
for n in occurrence: |
|
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) |
|
|
|
token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, typical_p=args.typical_p) |
|
if token in args.token_stop: |
|
break |
|
all_tokens += [token] |
|
if token not in occurrence: |
|
occurrence[token] = 1 |
|
else: |
|
occurrence[token] += 1 |
|
|
|
tmp = pipeline.decode(all_tokens[out_last:]) |
|
if '\ufffd' not in tmp: |
|
out_str += tmp |
|
yield out_str |
|
out_last = i + 1 |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
yield out_str |
|
|
|
examples = [ |
|
["""女招待: 欢迎光临。您远道而来,想必一定很累了吧? |
|
|
|
深见: 不会……空气也清爽,也让我焕然一新呢 |
|
|
|
女招待: 是吗。那真是太好了 |
|
|
|
{我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。}""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1], |
|
["""{我叫嘉祥,家里经营着一家点心店。 |
|
为了追求独当一面的目标,我离开了老家,开了一家名为"La Soleil"的新糕点店。 |
|
原本想独自一人打拼,却没想到,在搬家的箱子中发现了意想不到的人。 |
|
她叫巧克力,是我家的猫娘,没想到她竟然用这种方式跟了过来。} |
|
|
|
嘉祥: 别以为这样就可以蒙混过去!你在干嘛啊,巧克力! |
|
|
|
巧克力: 欸嘿嘿……那个,好、好久不见了呢,主人…… |
|
|
|
嘉祥: 昨天才在家里见过面不是吗。 |
|
|
|
巧克力: 这个……话是这么说没错啦……啊哈哈……""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1], |
|
["""莲华: 你的目的,就是这个万华镜吧? |
|
|
|
{莲华拿出了万华镜。} |
|
|
|
深见: 啊…… |
|
|
|
{好像被万华镜拽过去了一般,我的腿不由自主地向它迈去} |
|
|
|
深见: 是这个……就是这个啊…… |
|
|
|
{烨烨生辉的魔法玩具。 |
|
连接现实与梦之世界的、诱惑的桥梁。} |
|
|
|
深见: 请让我好好看看…… |
|
|
|
{我刚想把手伸过去,莲华就一下子把它收了回去。}""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1], |
|
["""{我叫嘉祥,有两只可爱的猫娘,名字分别是巧克力和香草。} |
|
|
|
嘉祥: 偶尔来一次也不错。 |
|
|
|
{我坐到客厅的沙发上,拍了拍自己的大腿。} |
|
|
|
巧克力&香草: 喵喵? |
|
|
|
巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪ |
|
|
|
巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪ |
|
|
|
香草: 身为猫娘饲主,这点服务也是应该的对吧? |
|
|
|
香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪ |
|
|
|
{我摸摸各自占据住我左右两腿的两颗猫头。} |
|
|
|
嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1], |
|
["""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。 |
|
巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。 |
|
她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。} |
|
|
|
嘉祥: 很棒啊,巧克力!你真是懂不少东西呢! |
|
|
|
巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!""", 200, 0.7, 1.0, 0, 1.0, 0.1, 0.1], |
|
] |
|
|
|
iface = gr.Interface( |
|
fn=infer, |
|
description=f'''这是GalGame剧本续写模型(实验性质,不保证效果)。<b>请点击例子(在页面底部)</b>,可编辑内容。这里只看输入的最后约1200字,请写好,标点规范,无错别字,否则电脑会模仿你的错误。<b>为避免占用资源,每次生成限制长度。可将输出内容复制到输入,然后继续生成</b>。推荐提高temp改善文采,降低topp改善逻辑,提高两个penalty避免重复,具体幅度请自己实验。<br /> {desc}''', |
|
allow_flagging="never", |
|
inputs=[ |
|
gr.Textbox(lines=10, label="Prompt 输入的前文", value="""{我叫嘉祥,在日本开了一家名为La Soleil的糕点店,同时也是猫娘巧克力的主人。 |
|
巧克力是非常聪明的猫娘,她去国外留学了一段时间,向Alice教授学习,拿到了计算机博士学位。 |
|
她会各种程序语言,对世界各地的风土人情都十分了解,也掌握了很多数学、物理知识。} |
|
|
|
嘉祥: 很棒啊,巧克力!你真是懂不少东西呢! |
|
|
|
巧克力: 因为巧克力是主人的最佳拍挡兼猫娘情人呀♪为了主人,巧克力会解决各种问题!"""), |
|
gr.Slider(10, 2000, step=10, value=200, label="token_count 每次生成的长度"), |
|
gr.Slider(0.2, 2.0, step=0.1, value=0.7, label="temperature 默认0.7,高则变化丰富,低则保守求稳"), |
|
gr.Slider(0.0, 1.0, step=0.05, value=1.0, label="top_p 默认1.0,高则标新立异,低则循规蹈矩"), |
|
gr.Slider(0, 500, step=1, value=0, label="top_k 默认0(不过滤),0以上时高则标新立异,低则循规蹈矩"), |
|
gr.Slider(0.05, 1.0, step=0.05, value=1.0, label="typical_p 默认1.0,高则保留模型天性,低则试图贴近人类典型习惯"), |
|
gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="presencePenalty 默认0.0,避免写过的类似字"), |
|
gr.Slider(0.0, 1.0, step=0.1, value=0.1, label="countPenalty 默认0.0,额外避免写过多次的类似字"), |
|
], |
|
outputs=gr.Textbox(label="Output 输出的续写", lines=28), |
|
examples=examples, |
|
cache_examples=False, |
|
).queue() |
|
|
|
demo = gr.TabbedInterface( |
|
[iface], ["Generative"] |
|
) |
|
|
|
demo.queue(max_size=5) |
|
if args.share: |
|
demo.launch(share=True,server_name="0.0.0.0",server_port=58888) |
|
else: |
|
demo.launch(share=False,server_port=58888) |