LinkangZhan committed
Commit c2769e9
1 Parent(s): 370dfe5

support chat

Files changed (1): app.py (+46 -5)
app.py CHANGED
@@ -1,11 +1,43 @@
 from peft import PeftModel, PeftConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
 import gradio as gr
+import torch
 
-# config = PeftConfig.from_pretrained("Junity/Genshin-World-Model")
-# model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-13B-Base")
-# model = PeftModel.from_pretrained(model, "Junity/Genshin-World-Model")
-# tokenizer = AutoTokenizer.from_pretrained("Junity/Genshin-World-Model")
+config = PeftConfig.from_pretrained("Junity/Genshin-World-Model")
+model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
+model = PeftModel.from_pretrained(model, "Junity/Genshin-World-Model")
+tokenizer = AutoTokenizer.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
+
+history = []  # running transcript, one line per message
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
+
+def respond(role_name, msg, chatbot):
+    global history
+    if role_name:  # an empty textbox arrives as "", not None
+        history.append(role_name + ":" + msg)
+    else:
+        history.append(msg)
+    total_input = []
+    for message in reversed(history):  # newest to oldest, prepending until the budget is hit
+        content_tokens = tokenizer.encode(message + '\n')
+        total_input = content_tokens + total_input
+        if len(total_input) > 4096:
+            break
+    total_input = total_input[-4096:]  # keep at most the last 4096 tokens
+    input_ids = torch.LongTensor([total_input]).to(device)
+    generation_config = model.generation_config
+    stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+
+    def stream_generator():
+        outputs = []
+        for token in model.generate(input_ids, generation_config=stream_config):
+            outputs.append(token.item())
+            yield "", (chatbot or []) + [[msg, tokenizer.decode(outputs, skip_special_tokens=True)]]
+
+    yield from stream_generator()  # respond must itself be a generator for Gradio to stream
 
 with gr.Blocks() as demo:
     gr.Markdown(
@@ -15,4 +47,13 @@ with gr.Blocks() as demo:
     - 此模型不支持要求对方回答什么,只支持续写。
     """
     )
-demo.queue().launch()
+    with gr.Tab("聊天") as chat:  # "Chat"
+        role_name = gr.Textbox(label="你将扮演的角色")  # "The role you will play"
+        msg = gr.Textbox(label="输入")  # "Input"
+        with gr.Row():
+            clear = gr.Button("Clear")
+            sub = gr.Button("Submit")
+        chatbot = gr.Chatbot()
+        sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot])
+        clear.click(lambda: None, None, chatbot, queue=False)
+demo.queue().launch()
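
Note on the streaming setup: the new code imports NewGenerationMixin but never applies it, and in transformers_stream_generator the do_stream=True flag only takes effect after the model's generate has been swapped for the mixin's streaming version. Below is a minimal sketch of the usual wiring, reusing the model, tokenizer, and device defined in app.py; unwrapping the PEFT adapter with get_base_model() and patching its class are assumptions of this sketch, not something the commit shows.

    from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig

    # Patch the underlying causal LM so generate() honors do_stream=True
    # and yields token ids one at a time instead of a finished sequence.
    base = model.get_base_model()  # assumption: unwrap the PeftModel to the Baichuan model
    base.__class__.generate = NewGenerationMixin.generate
    base.__class__.sample_stream = NewGenerationMixin.sample_stream

    stream_config = StreamGenerationConfig(**base.generation_config.to_dict(), do_stream=True)
    input_ids = tokenizer("你好", return_tensors="pt").input_ids.to(device)  # "hello"
    for token in base.generate(input_ids, generation_config=stream_config):
        # each iteration yields one new token id as a 0-d tensor
        print(tokenizer.decode([token.item()], skip_special_tokens=True), end="", flush=True)

With a patch like this in place, the respond handler above can keep passing generation_config=stream_config to model.generate and iterate over the yielded tokens exactly as the diff does.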