wanicca committed on
Commit
7464087
1 Parent(s): 4cee950

Add Gradio demo

Files changed (4)
  1. 20B_tokenizer.json +0 -0
  2. README.md +1 -1
  3. app.py +153 -0
  4. requirements.txt +5 -0
20B_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
  sdk_version: 3.27.0
  app_file: app.py
  pinned: false
- license: wtfpl
+ license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,153 @@
+ import gradio as gr
+ import argparse
+ import os, gc, torch
+ from datetime import datetime
+ from huggingface_hub import hf_hub_download
+ from pynvml import *
+ nvmlInit()
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
+ ctx_limit = 4096
+ desc = f'''Links: <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 0.5em">ChatRWKV</a><a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 0.5em">RWKV-LM</a><a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 0.5em">RWKV pip package</a><a href="https://zhuanlan.zhihu.com/p/618011122" target="_blank" style="margin:0 0.5em">Zhihu tutorial</a>
+ '''
+
+ parser = argparse.ArgumentParser(prog='ChatGal RWKV')
+ parser.add_argument('--share', action='store_true')
+ args = parser.parse_args()
+
+ os.environ["RWKV_JIT_ON"] = '1'
+
+ from rwkv.model import RWKV
+ model_path = hf_hub_download(repo_id="Synthia/ChatGalRWKV", filename="RWKV-4-Novel-7B-v1-Chn-20230409-ctx4096.pth")
+ if os.environ.get('ON_COLAB') == '1':  # .get avoids a KeyError when the variable is unset
+     os.environ["RWKV_CUDA_ON"] = '1'  # if '1' then use CUDA kernel for seq mode (much faster)
+     model = RWKV(model=model_path, strategy='cuda bf16')
+ else:
+     model = RWKV(model=model_path, strategy='cpu bf16')
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
+ pipeline = PIPELINE(model, "20B_tokenizer.json")
+
+ def infer(
+         ctx,
+         token_count=10,
+         temperature=0.7,
+         top_p=1.0,
+         presencePenalty=0.05,
+         countPenalty=0.05,
+ ):
+     args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                          alpha_frequency=countPenalty,
+                          alpha_presence=presencePenalty,
+                          token_ban=[0],  # ban the generation of some tokens
+                          token_stop=[])  # stop generation whenever you see any token here
+
+     ctx = ctx.strip().split('\n')
+     for c in range(len(ctx)):
+         ctx[c] = ctx[c].strip().strip('\u3000').strip('\r')
+     ctx = list(filter(lambda c: c != '', ctx))
+     ctx = '\n' + ('\n'.join(ctx)).strip()
+     if ctx == '':
+         ctx = '\n'
+
+     # gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     # print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}', flush=True)
+
+     all_tokens = []
+     out_last = 0
+     out_str = ''
+     occurrence = {}
+     state = None
+     for i in range(int(token_count)):
+         out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+         for n in args.token_ban:
+             out[n] = -float('inf')
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         if token in args.token_stop:
+             break
+         all_tokens += [token]
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:  # emit only once the decoded bytes form complete UTF-8 characters
+             out_str += tmp
+             yield out_str
+             out_last = i + 1
+     gc.collect()
+     torch.cuda.empty_cache()
+     yield out_str
+
+ examples = [
+     ["""女招待: 欢迎光临。您远道而来,想必一定很累了吧?
+
+ 深见: 不会……空气也清爽,也让我焕然一新呢
+
+ 女招待: 是吗。那真是太好了
+
+ 我因为撰稿的需要,而造访了这间位于信州山间的温泉宿驿。""", 200, 0.7, 1.0, 0.05, 0.05],
+     ["翡翠: 欢迎回来,志贵少爷。", 200, 0.7, 1.0, 0.05, 0.05],
+     ["""莲华: 你的目的,就是这个万华镜吧?
+
+ 莲华拿出了万华镜。
+
+ 深见: 啊……
+
+ 好像被万华镜拽过去了一般,我的腿不由自主地向它迈去
+
+ 深见: 是这个……就是这个啊……
+
+ 烨烨生辉的魔法玩具。
+ 连接现实与梦之世界的、诱惑的桥梁。
+
+ 深见: 请让我好好看看……
+
+ 我刚想把手伸过去,莲华就一下子把它收了回去。""", 200, 0.7, 1.0, 0.05, 0.05],
+     ["""嘉祥: 偶尔来一次也不错。
+
+ 我坐到客厅的沙发上,拍了拍自己的大腿。
+
+ 巧克力&香草: 喵喵?
+
+ 巧克力: 咕噜咕噜咕噜~♪人家最喜欢让主人掏耳朵了~♪
+
+ 巧克力: 主人好久都没有帮我们掏耳朵了,现在人家超兴奋的喵~♪
+
+ 香草: 身为猫娘饲主,这点服务也是应该的对吧?
+
+ 香草: 老实说我也有点兴奋呢咕噜咕噜咕噜~♪
+
+ 我摸摸各自占据住我左右两腿的两颗猫头。
+
+ 嘉祥: 开心归开心,拜托你们俩别一直乱动啊,很危险的。""", 200, 0.7, 1.0, 0.05, 0.05],
+ ]
+
+ iface = gr.Interface(
+     fn=infer,
+     description=f'''This is a pure web-novel model: its English and code abilities were removed, but it is stronger at writing casual fiction. <b>Please click an example (at the bottom of the page)</b>; the text is editable. Only roughly the last 1200 characters of the input are used, so write them carefully, with standard punctuation and no typos, or the model will imitate your mistakes. <b>To avoid tying up resources, each generation is length-limited; you can copy the output back into the input and keep generating.</b> Raising temperature improves style, lowering top_p improves coherence, and raising the two penalties reduces repetition; experiment to find the right amounts. {desc}''',
+     allow_flagging="never",
+     inputs=[
+         gr.Textbox(lines=10, label="Prompt (the preceding text)", value="通过基因改造,修真"),  # prompt
+         gr.Slider(10, 200, step=10, value=200, label="token_count (length of each generation)"),  # token_count
+         gr.Slider(0.2, 2.0, step=0.1, value=0.7, label="temperature (default 0.7; higher is more varied, lower is more conservative)"),  # temperature
+         gr.Slider(0.0, 1.0, step=0.05, value=1.0, label="top_p (default 1.0; higher is more inventive, lower is more predictable)"),  # top_p
+         gr.Slider(0.0, 1.0, step=0.1, value=0.05, label="presencePenalty (default 0.05; discourages tokens that have already appeared)"),  # presencePenalty
+         gr.Slider(0.0, 1.0, step=0.1, value=0.05, label="countPenalty (default 0.05; further discourages tokens that appear many times)"),  # countPenalty
+     ],
+     outputs=gr.Textbox(label="Output (the generated continuation)", lines=28),
+     examples=examples,
+     cache_examples=False,
+ ).queue()
+
+ demo = gr.TabbedInterface(
+     [iface], ["Generative"]
+ )
+
+ demo.queue(max_size=5)
+ if args.share:
+     demo.launch(share=True)
+ else:
+     demo.launch(share=False)
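The sampling loop in infer() implements the two penalty sliders by recording, in `occurrence`, how many times each token has been generated, then subtracting presencePenalty + count * countPenalty from that token's logit before the next sample. A minimal standalone sketch of the same bookkeeping (the toy logits and greedy argmax are stand-ins for the RWKV model and pipeline.sample_logits, so this only illustrates the penalty logic):

    alpha_presence, alpha_frequency = 0.05, 0.05   # the two penalty sliders
    occurrence = {}                                # token id -> times sampled so far
    logits = [1.0, 0.9, 0.8, 0.7]                  # toy scores for a 4-token vocabulary

    for step in range(8):
        penalized = list(logits)
        for tok, n in occurrence.items():
            # flat penalty for having appeared at all, plus a per-repetition penalty
            penalized[tok] -= alpha_presence + n * alpha_frequency
        token = max(range(len(penalized)), key=penalized.__getitem__)  # greedy stand-in
        occurrence[token] = occurrence.get(token, 0) + 1
        print(step, token)

Run it and the argmax drifts away from token 0 as its accumulated penalty grows, which is exactly how the demo discourages repetition.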
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ ninja
+ tokenizers
+ rwkv
+ pynvml
+ huggingface_hub