LinkangZhan commited on
Commit
b00538d
1 Parent(s): d04c3ba

no function

Browse files
Files changed (1) hide show
  1. app.py +57 -21
app.py CHANGED
@@ -5,23 +5,48 @@ from threading import Thread
5
  import gradio as gr
6
  import torch
7
 
8
- config = PeftConfig.from_pretrained("Junity/Genshin-World-Model", trust_remote_code=True)
9
- model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-13B-Base", torch_dtype=torch.float32, trust_remote_code=True)
10
- model = PeftModel.from_pretrained(model, r"Junity/Genshin-World-Model", torch_dtype=torch.float32, trust_remote_code=True)
11
- tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Base", trust_remote_code=True)
12
- history = []
13
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- if device == "cuda":
15
- model.cuda()
16
- model = model.half()
17
 
18
- def respond(role_name, msg, textbox):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  if textbox != '':
20
- textbox = textbox + "\n" + role_name + ":" + msg + ('。' if msg[-1] not in ['。', '!', '?'] else '') + '\n'
 
 
 
 
 
21
  yield ["", textbox]
22
  else:
23
- textbox = textbox + role_name + ":" + msg + ('。' if msg[-1] not in ['。', '!', '?'] else '') + '\n'
 
 
 
 
 
24
  yield ["", textbox]
 
 
25
  input_ids = tokenizer.encode(textbox)[-4096:]
26
  input_ids = torch.LongTensor([input_ids]).to(device)
27
  generation_config = model.generation_config
@@ -48,21 +73,32 @@ def respond(role_name, msg, textbox):
48
  textbox += new_text
49
  yield ["", textbox]
50
 
 
51
  with gr.Blocks() as demo:
52
  gr.Markdown(
53
  """
54
  ## Genshin-World-Model
55
  - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
56
  - 此模型不支持要求对方回答什么,只支持续写。
 
57
  """
58
  )
59
- with gr.Tab("聊天") as chat:
60
- role_name = gr.Textbox(label="你将扮演的角色")
61
- msg = gr.Textbox(label="输入")
 
 
 
 
62
  with gr.Row():
63
- clear = gr.Button("Clear")
64
- sub = gr.Button("Submit")
65
- textbox = gr.Textbox(interactive=False)
66
- sub.click(fn=respond, inputs=[role_name, msg, textbox], outputs=[msg, textbox])
67
- clear.click(lambda: None, None, textbox, queue=False)
68
- demo.queue().launch(server_port=6006)
 
 
 
 
 
 
5
  import gradio as gr
6
  import torch
7
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # lora_folder = ''
10
+ # model_folder = ''
11
+ #
12
+ # config = PeftConfig.from_pretrained(("Junity/Genshin-World-Model" if lora_folder == ''
13
+ # else lora_folder),
14
+ # trust_remote_code=True)
15
+ # model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
16
+ # else model_folder), torch_dtype=torch.float32, trust_remote_code=True)
17
+ # model = PeftModel.from_pretrained(model,
18
+ # ("Junity/Genshin-World-Model" if lora_folder == ''
19
+ # else lora_folder)
20
+ # , torch_dtype=torch.float32, trust_remote_code=True)
21
+ # tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
22
+ # else model_folder),
23
+ # trust_remote_code=True)
24
+ # history = []
25
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ # if device == "cuda":
27
+ # model.cuda()
28
+ # model = model.half()
29
+
30
+
31
+ def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
32
  if textbox != '':
33
+ textbox = (textbox
34
+ + "\n"
35
+ + role_name
36
+ + (":" if role_name != '' else '')
37
+ + msg
38
+ + ('。\n' if msg[-1] not in ['。', '!', '?'] else ''))
39
  yield ["", textbox]
40
  else:
41
+ textbox = (textbox
42
+ + role_name
43
+ + (":" if role_name != '' else '')
44
+ + msg
45
+ + ('。' if msg[-1] not in ['。', '!', '?', ')', '}', ':', ':', '('] else '')
46
+ + ('\n' if msg[-1] in ['。', '!', '?', ')', '}'] else ''))
47
  yield ["", textbox]
48
+ if character_name != '':
49
+ textbox += ('\n' if textbox[-1] != '\n' else '') + character_name + ':'
50
  input_ids = tokenizer.encode(textbox)[-4096:]
51
  input_ids = torch.LongTensor([input_ids]).to(device)
52
  generation_config = model.generation_config
 
73
  textbox += new_text
74
  yield ["", textbox]
75
 
76
+
77
  with gr.Blocks() as demo:
78
  gr.Markdown(
79
  """
80
  ## Genshin-World-Model
81
  - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
82
  - 此模型不支持要求对方回答什么,只支持续写。
83
+ - 目前运行不了,因为没有钱租卡。
84
  """
85
  )
86
+ with gr.Tab("创作") as chat:
87
+ role_name = gr.Textbox(label="你将扮演的角色(可留空)")
88
+ character_name = gr.Textbox(label="对方的角色(可留空)")
89
+ msg = gr.Textbox(label="你说的话")
90
+ with gr.Row():
91
+ clear = gr.ClearButton()
92
+ sub = gr.Button("Submit", variant="primary")
93
  with gr.Row():
94
+ temp = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.5, label="温度(调大则更随机)", interactive=True)
95
+ rep = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.0, label="对重复生成的惩罚", interactive=True)
96
+ max_len = gr.Slider(minimum=4, maximum=512, step=4, value=256, label="对方回答的最大长度", interactive=True)
97
+ top_p = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7, label="Top-p(调大则更随机)", interactive=True)
98
+ top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k(调大则更随机)", interactive=True)
99
+ textbox = gr.Textbox(interactive=True, label="全部文本(可修改)")
100
+ clear.add([msg, role_name, textbox])
101
+ sub.click(fn=respond,
102
+ inputs=[role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k],
103
+ outputs=[msg, textbox])
104
+ demo.queue().launch(server_port=6006, share=True)