BlinkDL committed
Commit 7c790c0
Parent: 24b9beb

Update app.py

Files changed (1)
  1. app.py +97 -4
app.py CHANGED
@@ -1,7 +1,100 @@
  import gradio as gr

- def greet(name):
-     return 'Please use https://huggingface.co/spaces/yahma/rwkv-14b first :)'

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import os
+ from datetime import datetime
+ from huggingface_hub import hf_hub_download

+ title = "RWKV-4 14B fp16 ctx4096"
+ desc = '''Links:
+ <a href='https://github.com/BlinkDL/ChatRWKV' target="_blank" style="margin:0 1em">ChatRWKV</a>
+ <a href='https://github.com/BlinkDL/RWKV-LM' target="_blank" style="margin:0 1em">RWKV-LM</a>
+ <a href="https://pypi.org/project/rwkv/" target="_blank" style="margin:0 1em">RWKV pip package</a>
+ '''

+ os.environ["RWKV_JIT_ON"] = '1'
+ os.environ["RWKV_CUDA_ON"] = '1'  # if '1', use the CUDA kernel for seq mode (much faster)
+
+ from rwkv.model import RWKV
+ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-169m", filename="RWKV-4-Pile-169M-20220807-8023.pth")
+ model = RWKV(model=model_path, strategy='cuda fp16')
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
+ pipeline = PIPELINE(model, "20B_tokenizer.json")
+
+ def infer(
+         ctx,
+         token_count=10,
+         temperature=1.0,
+         top_p=0.85,
+         presencePenalty=0.1,
+         countPenalty=0.1,
+ ):
+     args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                          alpha_frequency=countPenalty,
+                          alpha_presence=presencePenalty,
+                          token_ban=[0],   # ban the generation of some tokens
+                          token_stop=[])   # stop generation whenever any of these tokens is seen
+
+     # normalize the prompt: single leading newline, no trailing spaces
+     ctx = ctx.strip(' ')
+     if ctx.endswith('\n'):
+         ctx = f'\n{ctx.strip()}\n'
+     else:
+         ctx = f'\n{ctx.strip()}'
+
+     all_tokens = []
+     out_last = 0
+     out_str = ''
+     occurrence = {}
+     state = None
+     for i in range(int(token_count)):
+         # feed the whole prompt once, then one sampled token at a time
+         out, state = model.forward(pipeline.encode(ctx) if i == 0 else [token], state)
+         for n in args.token_ban:
+             out[n] = -float('inf')
+         for n in occurrence:
+             # presence penalty plus a frequency penalty that grows with the count
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         if token in args.token_stop:
+             break
+         all_tokens += [token]
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:  # only yield once a complete UTF-8 character is decoded
+             out_str += tmp
+             yield out_str.strip()
+             out_last = i + 1
+     yield out_str.strip()
+
+ examples = [
+     ["Ask Expert\n\nQuestion:\nWhat are some good plans for world peace?\n\nExpert Full Answer:\n", 100, 1.0, 0.85, 0.1, 0.1],
+     ["Q & A\n\nQuestion:\nWhy is the sky blue?\n\nDetailed Expert Answer:\n", 100, 1.0, 0.85, 0.1, 0.1],
+     ["Expert Questions & Helpful Answers\nAsk Research Experts\nQuestion:\nCan you write a short story about an elf maiden named Julia that meets a warrior named Rallio and they go on an adventure together?\n\nFull Answer:\n", 100, 1.0, 0.85, 0.1, 0.1],
+ ]
+
+ iface = gr.Interface(
+     fn=infer,
+     description=f'''{desc}''',
+     allow_flagging="never",
+     inputs=[
+         gr.Textbox(lines=20, label="Prompt"),        # prompt
+         gr.Slider(10, 200, step=10, value=100),      # token_count
+         gr.Slider(0.2, 2.0, step=0.1, value=1.0),    # temperature
+         gr.Slider(0.0, 1.0, step=0.05, value=0.85),  # top_p
+         gr.Slider(0.0, 1.0, step=0.1, value=0.1),    # presencePenalty
+         gr.Slider(0.0, 1.0, step=0.1, value=0.1),    # countPenalty
+     ],
+     outputs=gr.Textbox(label="Generated Output", lines=35),
+     examples=examples,
+     cache_examples=False,
+ ).queue()
+
+ demo = gr.TabbedInterface(
+     [iface], ["Generative"],
+     title=title,
+ )
+
+ demo.queue()
+ demo.launch(share=False)