kenplusplus committed
Commit fa02e71
1 Parent(s): d0a5cbc

use vicuna

Signed-off-by: Lu Ken <[email protected]>

Files changed (2):
  1. app.py +15 -8
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,20 +1,27 @@
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
 import gradio as gr
+import torch
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
-model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.3", trust_remote_code=True)
+model = LlamaForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.3", trust_remote_code=True).to(DEVICE)
 model = model.eval()
 
 def predict(input, history=None):
     if history is None:
         history = []
-    response, history = model.chat(tokenizer, input, history)
-    return history, history
+    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
+    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
+    history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()
+    # convert the tokens to text, and then split the responses into the right format
+    response = tokenizer.decode(history[0]).split("<|endoftext|>")
+    response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list
+    return response, history
 
 
 with gr.Blocks() as demo:
-    gr.Markdown('''## ChatGLM-6B - unofficial demo
-    Unnoficial demo of the [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B/blob/main/README_en.md) model, trained on 1T tokens of English and Chinese
+    gr.Markdown('''## Confidential HuggingFace Runner
     ''')
     state = gr.State([])
     chatbot = gr.Chatbot([], elem_id="chatbot").style(height=400)
@@ -25,4 +32,4 @@ with gr.Blocks() as demo:
     button = gr.Button("Generate")
     txt.submit(predict, [txt, state], [chatbot, state])
     button.click(predict, [txt, state], [chatbot, state])
-demo.queue().launch()
+demo.queue().launch(share=True, server_name="0.0.0.0")
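
A note on the new predict(): it keeps the DialoGPT-style chat loop, where turns are joined by the tokenizer's EOS token, but two details look fragile with a Llama-family checkpoint such as Vicuna. On the very first turn `history` is `[]`, so `torch.LongTensor(history)` yields an empty 1-D tensor that cannot be concatenated with the 2-D `new_user_input_ids`; and Llama tokenizers decode their EOS as `</s>`, not GPT-2's `<|endoftext|>`, so the split never matches and the chatbot renders nothing. Below is a minimal sketch of a guarded variant, assuming the module-level `tokenizer` and `model` defined above; it is an illustration, not part of the commit, and untested against this Space.

```python
# Hedged sketch, not part of the commit: assumes the module-level
# `tokenizer` and `model` created in app.py above.
import torch

def predict(input, history=None):
    if history is None:
        history = []
    # Encode the new user turn, terminated by the tokenizer's EOS token.
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
    if history:
        # `history` is the token-id list returned by the previous generate() call.
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        # First turn: torch.LongTensor([]) is 1-D and cannot be concatenated.
        bot_input_ids = new_user_input_ids
    history = model.generate(bot_input_ids, max_length=1000,
                             pad_token_id=tokenizer.eos_token_id).tolist()
    # Llama-family tokenizers render EOS as "</s>", not "<|endoftext|>",
    # so split on the tokenizer's own EOS string.
    turns = tokenizer.decode(history[0]).split(tokenizer.eos_token)
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return response, history
```

Vicuna v1.3 was also fine-tuned on a "USER: ... ASSISTANT:" conversation template, so EOS-delimited turns may still yield off-template completions; wrapping the history in that template before generate() would likely track the model's training more closely.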
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 torch
-transformers==4.27.1
 cpm_kernels
-icetk
+icetk
+gradio==3.50.2
+accelerate
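
Two notes on the pins: `gradio==3.50.2` is a 3.x release, which still supports the `Chatbot(...).style(height=400)` call used in app.py (the `.style()` API was removed in Gradio 4.x), and `accelerate` is presumably added to support the `transformers` model-loading path. To reproduce the Space locally, install the pinned dependencies with `pip install -r requirements.txt` and start the app with `python app.py`; the final `launch(share=True, server_name="0.0.0.0")` call binds all interfaces and also requests a public Gradio share link.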