gaochangkuan committed on
Commit
4272568
1 Parent(s): a0c8ee9

Create app.py

Files changed (1)
  1. app.py +169 -0
app.py ADDED
@@ -0,0 +1,169 @@
+ # import os
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3"
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load the Zhuji web-novel writing model and its tokenizer.
+ model_path = "CubeAI/Zhuji-Internet-Literature-Intelligent-Writing-Model-V1.0"
+ tokenizer = AutoTokenizer.from_pretrained(model_path, encode_special_tokens=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     attn_implementation="flash_attention_2",
+     device_map="auto",
+ )
+
+ model = torch.compile(model)
+ model = model.eval()
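Note: `attn_implementation="flash_attention_2"` requires the `flash-attn` package and a recent NVIDIA GPU. A minimal sketch of a guarded fallback, assuming the stock `sdpa` backend available in recent `transformers` releases is acceptable:

    # Minimal sketch: fall back to PyTorch's SDPA backend when flash-attn is missing.
    try:
        import flash_attn  # noqa: F401
        attn_impl = "flash_attention_2"
    except ImportError:
        attn_impl = "sdpa"  # built-in scaled-dot-product attention
    # then pass attn_implementation=attn_impl to from_pretrained(...)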
+
+ import gradio as gr
+ from threading import Thread
+ from typing import Iterator
+ from transformers import TextIteratorStreamer
+
+ DESCRIPTION = '''
+ <div>
+ <h1 style="text-align: center;">Long-Form Novel Demo for Our In-House Model</h1>
+ <p>This Space showcases our in-house model's capabilities for long-form fiction. The model is specially optimized for novel generation and comprehension tasks, and comes in two sizes: a base version and an advanced version.</p>
+ <p>📚 If you are interested in applying the model to novel writing and analysis, feel free to start exploring with the base version.</p>
+ <p>🚀 For users seeking more advanced features and deeper analysis, the advanced version offers stronger generation and finer-grained text understanding.</p>
+ </div>
+ '''
+
+ LICENSE = """
+ <p/>
+ ---
+ Built with NovelGen
+ """
+
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">AI-powered writing</h1>
+    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">AI-assisted writing</p>
+ </div>
+ """
+
+ css = """
+ h1 {
+   text-align: center;
+   display: block;
+ }
+ #duplicate-button {
+   margin: auto;
+   color: white;
+   background: #1565c0;
+   border-radius: 100vh;
+ }
+ """
+ # Prompt template: <|system|> and <|user|> turns are closed with <|observation|>,
+ # and each user turn is followed by <|assistant|> to cue the model's reply.
+ # (Written as adjacent string literals; stripping all spaces from the template,
+ # as the original .replace(" ", "") did, would break the Jinja tags.)
+ tokenizer.chat_template = (
+     "{% for message in messages %}"
+     "{% if message['role'] == 'user' %}"
+     "{{ '<|user|>' + message['content'].strip() + '<|observation|>' + '<|assistant|>' }}"
+     "{% elif message['role'] == 'system' %}"
+     "{{ '<|system|>' + message['content'].strip() + '<|observation|>' }}"
+     "{% elif message['role'] == 'assistant' %}"
+     "{{ message['content'] + '<|observation|>' }}"
+     "{% endif %}"
+     "{% endfor %}"
+ )
+
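To sanity-check what this template renders, `apply_chat_template` can be called with `tokenize=False`; a minimal sketch (the sample messages are illustrative):

    # Render the template on a toy conversation to inspect the raw prompt.
    sample = [
        {"role": "system", "content": ""},
        {"role": "user", "content": "你好"},  # "hello"
    ]
    print(tokenizer.apply_chat_template(sample, tokenize=False))
    # -> <|system|><|observation|><|user|>你好<|observation|><|assistant|>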
+ def chat_zhuji(
+     message: str,
+     history: list,
+     temperature: float,
+     max_new_tokens: int
+ ) -> Iterator[str]:
+     """
+     Generate a streaming response from the Zhuji writing model.
+
+     Args:
+         message (str): The input message.
+         history (list): The conversation history used by ChatInterface.
+         temperature (float): The sampling temperature.
+         max_new_tokens (int): The maximum number of new tokens to generate.
+
+     Yields:
+         str: The response generated so far.
+     """
+     # Rebuild the conversation as role-tagged messages: <|system|><|observation|><|user|>...
+     # (The role must be "assistant", not "<|assistant|>", or the template above
+     # silently drops every past model reply.)
+     conversation = []
+     for user, assistant in history:
+         conversation.extend([
+             {"role": "system", "content": ""},
+             {"role": "user", "content": user},
+             {"role": "assistant", "content": assistant},
+         ])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         penalty_alpha=0.65,  # contrastive-search knob; transformers may ignore it when do_sample=True
+         top_p=0.90,
+         top_k=35,
+         use_cache=True,
+         # <|observation|> closes every turn, so it also serves as the stop token.
+         eos_token_id=tokenizer.encode("<|observation|>", add_special_tokens=False),
+         temperature=temperature,
+     )
+     # Sampling with temperature 0 would crash, so fall back to greedy decoding.
+     if temperature == 0:
+         generate_kwargs['do_sample'] = False
+
+     # Run generation on a background thread so tokens can be yielded as they stream in.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
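The generator can also be consumed outside Gradio; a minimal sketch (the prompt is one of the UI examples below, and the parameter values are illustrative):

    # Drain the stream and keep the final accumulated response.
    response = ""
    for partial in chat_zhuji(
        "请生成4个东方神功的招式名称",  # "generate four Eastern martial-arts move names"
        history=[],
        temperature=0.8,
        max_new_tokens=512,
    ):
        response = partial  # each yield is the full response so far
    print(response)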
+ # Gradio UI
+ chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+ text_box = gr.Textbox(show_copy_button=True)
+ with gr.Blocks(fill_height=True, css=css) as demo:
+     # gr.Markdown(DESCRIPTION)
+     gr.ChatInterface(
+         fn=chat_zhuji,
+         chatbot=chatbot,
+         textbox=text_box,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Slider(minimum=0,
+                       maximum=1,
+                       step=0.1,
+                       value=0.95,
+                       label="Temperature",
+                       render=False),
+             gr.Slider(minimum=2048,
+                       maximum=8192*2,
+                       step=1,
+                       value=8192*2,
+                       label="Max new tokens",
+                       render=False),
+         ],
+         examples=[
+             ['请给一个古代美女的外貌来一段描写'],  # "write a description of a classical beauty's appearance"
+             ['请生成4个东方神功的招式名称'],  # "generate four Eastern martial-arts move names"
+             ['生成一段官军和倭寇打斗的场面描写'],  # "describe a battle between government troops and wokou pirates"
+             ['生成一个都市大女主的角色档案'],  # "create a character profile for an urban-fiction heroine"
+         ],
+         cache_examples=False,
+     )
+
+     gr.Markdown(LICENSE)
+
+ if __name__ == "__main__":
+     demo.launch(
+         # server_name='0.0.0.0',
+         # server_port=config.webui_config.port,
+         # inbrowser=True,
+         share=True
+     )
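`share=True` requests a temporary public `gradio.live` tunnel, which matters mainly for local runs; a hosted Space already exposes its own URL. A minimal local-run sketch with an explicit host and port (values are illustrative):

    # Serve locally on a fixed address instead of a share tunnel.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)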