LinkangZhan committed on
Commit
a6636f6
1 Parent(s): b2dcf43
Files changed (1) hide show
  1. app.py +38 -31
app.py CHANGED
@@ -1,43 +1,50 @@
1
  from peft import PeftModel, PeftConfig
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
 
4
  import gradio as gr
5
  import torch
6
 
7
  config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
8
- model = AutoModelForCausalLM.from_pretrained("../Baichuan/models--baichuan-inc--Baichuan-13B-Base\snapshots\Baichuan-13B-Base", trust_remote_code=True)
9
- model = PeftModel.from_pretrained(model, r"../Baichuan/r64alpha32dropout0.5loss0.007/checkpoint-5000", trust_remote_code=True)
10
- tokenizer = AutoTokenizer.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
11
-
12
  history = []
13
- device = "cpu"
14
-
 
 
15
 
16
- def respond(role_name, msg, chatbot, character):
17
- global history
18
- if role_name is not None:
19
- history.append(role_name + "" + msg)
20
  else:
21
- history.append(msg)
22
- total_input = []
23
- for i, message in enumerate(history[::-1]):
24
- content_tokens = tokenizer.encode(message + '\n')
25
- total_input = content_tokens + total_input
26
- if content_tokens + total_input > 4096:
27
- break
28
- total_input = total_input[-4096:]
29
- input_ids = torch.LongTensor([total_input]).to(device)
30
  generation_config = model.generation_config
31
  stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- def stream_generator():
34
- outputs = []
35
- for token in model.generate(input_ids, generation_config=stream_config):
36
- outputs.append(token.item())
37
- yield None, tokenizer.decode(outputs, skip_special_tokens=True)
38
-
39
- return stream_generator()
40
 
 
 
 
41
 
42
  with gr.Blocks() as demo:
43
  gr.Markdown(
@@ -53,7 +60,7 @@ with gr.Blocks() as demo:
53
  with gr.Row():
54
  clear = gr.Button("Clear")
55
  sub = gr.Button("Submit")
56
- chatbot = gr.Chatbot()
57
- sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot])
58
- clear.click(lambda: None, None, chatbot, queue=False)
59
- demo.queue().launch()
 
1
  from peft import PeftModel, PeftConfig
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
  from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
4
+ from threading import Thread
5
  import gradio as gr
6
  import torch
7
 
8
  config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
9
# Load the Baichuan-13B base model, then wrap it with the PEFT adapter
# published as "Junity/Genshin-World-Model". The tokenizer comes from the
# base model repository.
model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-13B-Base", torch_dtype=torch.float32, device_map="auto", trust_remote_code=True)
model = PeftModel.from_pretrained(model, r"Junity/Genshin-World-Model", torch_dtype=torch.float32, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)

# Kept as a module-level name; the visible code no longer reads it.
history = []

# Target device for the input tensors built in respond().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Compare the device *type*, not the device object itself: a torch.device
# never compares equal to a plain string, so `device == "cuda"` was always
# False and the half-precision switch below never ran on GPU hosts.
# NOTE(review): the model is already placed by device_map="auto"; confirm the
# explicit .cuda() call is still wanted on multi-GPU / offloaded setups.
if device.type == "cuda":
    model.cuda()
    model = model.half()
17
 
18
def respond(role_name, msg, textbox):
    """Append a speaker line to the transcript and stream the model's continuation.

    The new line is formatted as ``role_name + ":" + msg``, closed with '。'
    unless ``msg`` already ends with '。', '!' or '?', and terminated with a
    newline. When the transcript is non-empty, an extra newline separates it
    from the new line.

    Args:
        role_name: Name of the speaker for this line.
        msg: The speaker's utterance. May be empty; an empty utterance is
            still closed with '。' instead of raising ``IndexError``.
        textbox: The transcript accumulated so far ('' when starting fresh).

    Yields:
        ``["", transcript]`` pairs: first right after the speaker line has
        been appended, then once per text chunk produced by the model. The
        leading empty string is the value for the first output component
        (the message input in the visible UI wiring).
    """
    # str.endswith accepts a tuple and is safe on an empty string, whereas
    # indexing msg[-1] would raise IndexError for an empty message.
    terminator = '' if msg.endswith(('。', '!', '?')) else '。'
    line = role_name + ":" + msg + terminator + '\n'
    if textbox != '':
        textbox = textbox + "\n" + line
    else:
        textbox = textbox + line
    # Show the user's line immediately, before the (slow) generation starts.
    yield ["", textbox]

    # Keep only the most recent 4096 tokens of the transcript as the prompt.
    input_ids = tokenizer.encode(textbox)[-4096:]
    input_ids = torch.LongTensor([input_ids]).to(device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids,
        temperature=1.0,
        top_p=0.75,
        repetition_penalty=1.2,
        max_new_tokens=256,
        streamer=streamer,
    )

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains the streamer and forwards each chunk to the UI.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]

    # The streamer is exhausted only once generation has finished, so this
    # returns promptly; it just makes sure the worker is not left dangling.
    thread.join()
48
 
49
  with gr.Blocks() as demo:
50
  gr.Markdown(
 
60
  with gr.Row():
61
  clear = gr.Button("Clear")
62
  sub = gr.Button("Submit")
63
+ textbox = gr.Textbox(interactive=False)
64
+ sub.click(fn=respond, inputs=[role_name, msg, textbox], outputs=[msg, textbox])
65
+ clear.click(lambda: None, None, textbox, queue=False)
66
+ demo.queue().launch(server_port=6006)