kevinwang676 committed on
Commit 2ff7b06
1 Parent(s): a26bdb1

Update app.py

Files changed (1)
  1. app.py +233 -0
app.py CHANGED
@@ -45,6 +45,239 @@ import string
  import argparse
  import json

+ import gc, copy
+ from datetime import datetime
+ from huggingface_hub import hf_hub_download
+ from pynvml import *
+ nvmlInit()
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
+ ctx_limit = 1536
+ title = "RWKV-4-Raven-7B-v12-Eng98%-Other2%-20230521-ctx8192"
+
+ os.environ["RWKV_JIT_ON"] = '1'
+ os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
+
+ from rwkv.model import RWKV
+ model_path1 = hf_hub_download(repo_id="BlinkDL/rwkv-4-raven", filename=f"{title}.pth")
+ model1 = RWKV(model=model_path1, strategy='cuda fp16i8 *8 -> cuda fp16')
+ from rwkv.utils import PIPELINE, PIPELINE_ARGS
+ pipeline = PIPELINE(model1, "20B_tokenizer.json")
+
+ def generate_prompt(instruction, input=None):
+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
+     if input:
+         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+ # Instruction:
+ {instruction}
+ # Input:
+ {input}
+ # Response:
+ """
+     else:
+         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+ # Instruction:
+ {instruction}
+ # Response:
+ """
+
+ def evaluate(
+     instruction,
+     input=None,
+     token_count=200,
+     temperature=1.0,
+     top_p=0.7,
+     presencePenalty = 0.1,
+     countPenalty = 0.1,
+ ):
+     args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                          alpha_frequency = countPenalty,
+                          alpha_presence = presencePenalty,
+                          token_ban = [], # ban the generation of some tokens
+                          token_stop = [0]) # stop generation whenever you see any token here
+
+     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
+     input = input.strip().replace('\r\n','\n').replace('\n\n','\n')
+     ctx = generate_prompt(instruction, input)
+
+     all_tokens = []
+     out_last = 0
+     out_str = ''
+     occurrence = {}
+     state = None
+     for i in range(int(token_count)):
+         out, state = model1.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         if token in args.token_stop:
+             break
+         all_tokens += [token]
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:
+             out_str += tmp
+             yield out_str.strip()
+             out_last = i + 1
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+     del out
+     del state
+     gc.collect()
+     torch.cuda.empty_cache()
+     yield out_str.strip()
+
+ examples = [
+     ["Tell me about ravens.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Write a python function to mine 1 BTC, with details and comments.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Write a song about ravens.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Explain the following metaphor: Life is like cats.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Write a story using the following information", "A man named Alex chops a tree down", 300, 1.2, 0.5, 0.4, 0.4],
+     ["Generate a list of adjectives that describe a person as brave.", "", 300, 1.2, 0.5, 0.4, 0.4],
+     ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.", "", 300, 1.2, 0.5, 0.4, 0.4],
+ ]
+
+ ##########################################################################
+
+ chat_intro = '''The following is a coherent verbose detailed conversation between <|user|> and an AI girl named <|bot|>.
+ <|user|>: Hi <|bot|>, Would you like to chat with me for a while?
+ <|bot|>: Hi <|user|>. Sure. What would you like to talk about? I'm listening.
+ '''
+
+ def user(message, chatbot):
+     chatbot = chatbot or []
+     # print(f"User: {message}")
+     return "", chatbot + [[message, None]]
+
+ def alternative(chatbot, history):
+     if not chatbot or not history:
+         return chatbot, history
+
+     chatbot[-1][1] = None
+     history[0] = copy.deepcopy(history[1])
+
+     return chatbot, history
+
+ def chat(
+     prompt,
+     user,
+     bot,
+     chatbot,
+     history,
+     temperature=1.0,
+     top_p=0.8,
+     presence_penalty=0.1,
+     count_penalty=0.1,
+ ):
+     args = PIPELINE_ARGS(temperature=max(0.2, float(temperature)), top_p=float(top_p),
+                          alpha_frequency=float(count_penalty),
+                          alpha_presence=float(presence_penalty),
+                          token_ban=[], # ban the generation of some tokens
+                          token_stop=[]) # stop generation whenever you see any token here
+
+     if not chatbot:
+         return chatbot, history
+
+     message = chatbot[-1][0]
+     message = message.strip().replace('\r\n','\n').replace('\n\n','\n')
+     ctx = f"{user}: {message}\n\n{bot}:"
+
+     if not history:
+         prompt = prompt.replace("<|user|>", user.strip())
+         prompt = prompt.replace("<|bot|>", bot.strip())
+         prompt = prompt.strip()
+         prompt = f"\n{prompt}\n\n"
+
+         out, state = model1.forward(pipeline.encode(prompt), None)
+         history = [state, None, []] # [state, state_pre, tokens]
+         # print("History reloaded.")
+
+     [state, _, all_tokens] = history
+     state_pre_0 = copy.deepcopy(state)
+
+     out, state = model1.forward(pipeline.encode(ctx)[-ctx_limit:], state)
+     state_pre_1 = copy.deepcopy(state) # For recovery
+
+     # print("Bot:", end='')
+
+     begin = len(all_tokens)
+     out_last = begin
+     out_str: str = ''
+     occurrence = {}
+     for i in range(300):
+         if i <= 0:
+             nl_bias = -float('inf')
+         elif i <= 30:
+             nl_bias = (i - 30) * 0.1
+         elif i <= 130:
+             nl_bias = 0
+         else:
+             nl_bias = (i - 130) * 0.25
+         out[187] += nl_bias
+         for n in occurrence:
+             out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+         token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+         next_tokens = [token]
+         if token == 0:
+             next_tokens = pipeline.encode('\n\n')
+         all_tokens += next_tokens
+
+         if token not in occurrence:
+             occurrence[token] = 1
+         else:
+             occurrence[token] += 1
+
+         out, state = model1.forward(next_tokens, state)
+
+         tmp = pipeline.decode(all_tokens[out_last:])
+         if '\ufffd' not in tmp:
+             # print(tmp, end='', flush=True)
+             out_last = begin + i + 1
+             out_str += tmp
+
+             chatbot[-1][1] = out_str.strip()
+             history = [state, all_tokens]
+             yield chatbot, history
+
+         out_str = pipeline.decode(all_tokens[begin:])
+         out_str = out_str.replace("\r\n", '\n').replace('\\n', '\n')
+
+         if '\n\n' in out_str:
+             break
+
+         # State recovery
+         if f'{user}:' in out_str or f'{bot}:' in out_str:
+             idx_user = out_str.find(f'{user}:')
+             idx_user = len(out_str) if idx_user == -1 else idx_user
+             idx_bot = out_str.find(f'{bot}:')
+             idx_bot = len(out_str) if idx_bot == -1 else idx_bot
+             idx = min(idx_user, idx_bot)
+
+             if idx < len(out_str):
+                 out_str = f" {out_str[:idx].strip()}\n\n"
+                 tokens = pipeline.encode(out_str)
+
+                 all_tokens = all_tokens[:begin] + tokens
+                 out, state = model1.forward(tokens, state_pre_1)
+             break
+
+     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
+
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     chatbot[-1][1] = out_str.strip()
+     history = [state, state_pre_0, all_tokens]
+     yield chatbot, history
+
  from TTS.tts.utils.synthesis import synthesis
  from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
  try:
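
Note on usage: this hunk adds only the RWKV model setup and the `evaluate` / `chat` generator functions; the Gradio interface that calls them is outside the lines shown here. As a rough, hypothetical sketch (not code from this commit), the streaming `evaluate` generator could be wired to a demo roughly as follows. The component labels, slider ranges, and default values below are illustrative assumptions chosen to mirror the seven columns of the `examples` rows.

import gradio as gr  # hypothetical wiring; not part of this commit's diff

demo = gr.Interface(
    fn=evaluate,  # the streaming generator added in this hunk
    inputs=[
        gr.Textbox(lines=2, label="Instruction"),
        gr.Textbox(lines=2, label="Input"),
        gr.Slider(10, 500, value=300, step=10, label="Token Count"),
        gr.Slider(0.2, 2.0, value=1.2, step=0.1, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Top P"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Presence Penalty"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Count Penalty"),
    ],
    outputs=gr.Textbox(lines=10, label="Output"),
    examples=examples,  # the example rows defined above: 7 columns, one per input
    title=title,        # reuses the model title string set at the top of the hunk
)
demo.queue().launch()   # queue() lets Gradio stream the generator's yields

Because `evaluate` yields the accumulated string after each decoded chunk, enabling the queue is what makes partial output appear in the textbox while generation is still running.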