Kirili4ik commited on
Commit
21a5dba
1 Parent(s): 5246f84

clean and make 6ep model

Browse files
Files changed (2) hide show
  1. app.py +18 -136
  2. util_funcs.py +109 -0
app.py CHANGED
@@ -1,118 +1,7 @@
1
  import torch
2
  import gradio as gr
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
-
5
-
6
- def get_length_param(text: str, tokenizer) -> str:
7
- """Maps text to 1 of 4 buckets based on length after encoding.
8
-
9
- Parameters
10
- ----------
11
- text: str
12
- The text to be given 1 of 4 length parameters.
13
-
14
- tokenizer: HuggingFace tokenizer
15
- Tokenizer that used to compute the length of the text after encoding.
16
- For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
17
-
18
- Returns
19
- -------
20
- len_param: str
21
- One of four buckets:
22
- '1' for short, '2' for medium, '3' for long texts and '-' for all others.
23
- """
24
- tokens_count = len(tokenizer.encode(text))
25
- if tokens_count <= 15:
26
- len_param = '1'
27
- elif tokens_count <= 50:
28
- len_param = '2'
29
- elif tokens_count <= 256:
30
- len_param = '3'
31
- else:
32
- len_param = '-'
33
- return len_param
34
-
35
-
36
- def get_user_param(text: dict, machine_name_in_chat: str) -> str:
37
- """Maps text by 1/0 for it to be the person or the machine in the dialogue
38
-
39
- Parameters
40
- ----------
41
- text: Dict[..., 'from', ...]
42
- Dict containing field 'from' with the name of the user who sent the message
43
-
44
- machine_name_in_chat: str
45
- Str with the name of the machine - it will be predicted
46
- """
47
- if text['from'] == machine_name_in_chat:
48
- return '1' # machine
49
- else:
50
- return '0' # human
51
-
52
-
53
- def build_text_file(data_json: dict, dest_path: str,
54
- tokenizer, machine_name_in_chat='Кирилл Гельван'):
55
- """Create a text file for training in special format for ruDialoGPT-3.
56
-
57
- Parameters
58
- ----------
59
- data_json: dict
60
- Dict containing 'text' (message) and 'from' (user who sent the message)
61
-
62
- dest_path: str
63
- String containing path to write data there
64
-
65
- tokenizer: HuggingFace tokenizer
66
- Tokenizer that used to compute the length of the text after encoding.
67
- For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
68
- """
69
- f = open(dest_path, 'w')
70
- new_data = ''
71
- for i in range(len(data_json) - 1):
72
- message, next_message = data_json[i], data_json[i+1]
73
- if message['text'] == '' or type(message['text']) != str:
74
- continue
75
- if next_message['text'] == '' or type(next_message['text']) != str:
76
- continue
77
-
78
- user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
79
- length = get_length_param(data_json[i+1]['text'], tokenizer)
80
- message_text = re.sub(r"\n", ". ", message['text'])
81
- new_data += f"|{user}|{length}|{message_text}{tokenizer.eos_token}" + "\n"
82
-
83
- f.write(new_data)
84
-
85
-
86
- def load_dataset(train_path, test_path, tokenizer):
87
- """Creates train and test PyTorch datasets and collate_fn using HuggingFace.
88
-
89
- Parameters
90
- ----------
91
- train_path: str
92
- String containing path to train data
93
-
94
- test_path: str
95
- String containing path to test data
96
-
97
- tokenizer: HuggingFace tokenizer
98
- Tokenizer that used to compute the length of the text after encoding.
99
- For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
100
- """
101
- train_dataset = TextDataset(
102
- tokenizer = tokenizer,
103
- file_path = train_path,
104
- block_size = 256)
105
-
106
- test_dataset = TextDataset(
107
- tokenizer = tokenizer,
108
- file_path = test_path,
109
- block_size = 256)
110
-
111
- data_collator = DataCollatorForLanguageModeling(
112
- tokenizer=tokenizer, mlm=False
113
- )
114
- return train_dataset, test_dataset, data_collator
115
-
116
 
117
  def chat_function(message, length_of_the_answer, who_is_next, creativity): # model, tokenizer
118
 
@@ -138,12 +27,6 @@ def chat_function(message, length_of_the_answer, who_is_next, creativity): # m
138
  history = gr.get_state() or []
139
  chat_history_ids = torch.zeros((1, 0), dtype=torch.int) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)
140
 
141
- ######### next_who = input("Who's phrase?\t") #input("H / G?") # Human or GPT
142
-
143
- # In case Human
144
- ##### if next_who == "H":
145
-
146
- ######## input_user = input("===> Human: ")
147
  # encode the new user input, add parameters and return a tensor in Pytorch
148
  if len(input_user) != 0:
149
 
@@ -156,7 +39,6 @@ def chat_function(message, length_of_the_answer, who_is_next, creativity): # m
156
 
157
  if next_who == "G":
158
 
159
- ######## next_len = input("Phrase len? 1/2/3/-\t") #input("Exp. len?(-/1/2/3): ")
160
  # encode the new user input, add parameters and return a tensor in Pytorch
161
  new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
162
  # append the new user input tokens to the chat history
@@ -198,45 +80,45 @@ def chat_function(message, length_of_the_answer, who_is_next, creativity): # m
198
  html += f"<div class='resp_msg'>{resp_msg}</div>"
199
  html += "</div>"
200
  return html
201
-
 
 
202
 
203
 
204
  # Download checkpoint:
205
- checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram"
206
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
207
  model = AutoModelForCausalLM.from_pretrained(checkpoint)
208
  model = model.eval()
209
 
210
-
211
  checkbox_group = gr.inputs.CheckboxGroup(['Kirill', 'Me'], default=['Kirill'], type="value", label=None)
212
-
213
- inputs = gr.inputs.Textbox(lines=1, label="???")
214
- outputs = gr.outputs.Textbox(label="Kirill (GPT-2):")
215
  title = "Chat with Kirill (in Russian)"
216
  description = "Тут можно поболтать со мной. Но вместо меня бот. Оставь message пустым, чтобы Кирилл продолжил говорить. Подбробнее о технике по ссылке внизу."
217
  article = "<p style='text-align: center'><a href='https://github.com/Kirili4ik/ruDialoGpt3-finetune-colab'>Github with fine-tuning GPT-2 on your chat</a></p>"
218
  examples = [
219
- ["Привет, как дела?", 'medium', 'Kirill', 0.6],
220
  ["Сколько тебе лет?", 'medium', 'Kirill', 0.3],
221
  ]
222
 
223
- iface = gr.Interface(chat_function,
224
- [
225
- "text",
226
- gr.inputs.Radio(["short", "medium", "long"], default='medium'),
227
  gr.inputs.Radio(["Kirill", "Me"], default='Kirill'),
228
- gr.inputs.Slider(0, 1, default=0.6)
229
- ],
230
- "html",
231
  title=title, description=description, article=article, examples=examples,
232
  css= """
233
  .chatbox {display:flex;flex-direction:column}
234
  .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
235
  .user_msg {background-color:cornflowerblue;color:white;align-self:start}
236
  .resp_msg {background-color:lightgray;align-self:self-end}
237
- """,
238
- allow_screenshot=True,
239
  allow_flagging=False
240
  )
241
 
242
- iface.launch()
 
 
1
  import torch
2
  import gradio as gr
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from util_funcs import get_length_param
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  def chat_function(message, length_of_the_answer, who_is_next, creativity): # model, tokenizer
7
 
 
27
  history = gr.get_state() or []
28
  chat_history_ids = torch.zeros((1, 0), dtype=torch.int) if history == [] else torch.tensor(history[-1][2], dtype=torch.long)
29
 
 
 
 
 
 
 
30
  # encode the new user input, add parameters and return a tensor in Pytorch
31
  if len(input_user) != 0:
32
 
 
39
 
40
  if next_who == "G":
41
 
 
42
  # encode the new user input, add parameters and return a tensor in Pytorch
43
  new_user_input_ids = tokenizer.encode(f"|1|{next_len}|", return_tensors="pt")
44
  # append the new user input tokens to the chat history
 
80
  html += f"<div class='resp_msg'>{resp_msg}</div>"
81
  html += "</div>"
82
  return html
83
+
84
+
85
+
86
 
87
 
88
  # Download checkpoint:
89
+ checkpoint = "Kirili4ik/ruDialoGpt3-medium-finetuned-telegram-6ep"
90
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
91
  model = AutoModelForCausalLM.from_pretrained(checkpoint)
92
  model = model.eval()
93
 
94
+ # Gradio
95
  checkbox_group = gr.inputs.CheckboxGroup(['Kirill', 'Me'], default=['Kirill'], type="value", label=None)
 
 
 
96
  title = "Chat with Kirill (in Russian)"
97
  description = "Тут можно поболтать со мной. Но вместо меня бот. Оставь message пустым, чтобы Кирилл продолжил говорить. Подбробнее о технике по ссылке внизу."
98
  article = "<p style='text-align: center'><a href='https://github.com/Kirili4ik/ruDialoGpt3-finetune-colab'>Github with fine-tuning GPT-2 on your chat</a></p>"
99
  examples = [
100
+ ["Привет, как дела?", 'medium', 'Kirill', 0.5],
101
  ["Сколько тебе лет?", 'medium', 'Kirill', 0.3],
102
  ]
103
 
104
+ iface = gr.Interface(chat_function,
105
+ [
106
+ "text",
107
+ gr.inputs.Radio(["short", "medium", "long"], default='medium'),
108
  gr.inputs.Radio(["Kirill", "Me"], default='Kirill'),
109
+ gr.inputs.Slider(0, 1, default=0.5)
110
+ ],
111
+ "html",
112
  title=title, description=description, article=article, examples=examples,
113
  css= """
114
  .chatbox {display:flex;flex-direction:column}
115
  .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
116
  .user_msg {background-color:cornflowerblue;color:white;align-self:start}
117
  .resp_msg {background-color:lightgray;align-self:self-end}
118
+ """,
119
+ allow_screenshot=True,
120
  allow_flagging=False
121
  )
122
 
123
+ if __name__ == "__main__":
124
+ iface.launch()
util_funcs.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_length_param(text: str, tokenizer) -> str:
2
+ """Maps text to 1 of 4 buckets based on length after encoding.
3
+
4
+ Parameters
5
+ ----------
6
+ text: str
7
+ The text to be given 1 of 4 length parameters.
8
+
9
+ tokenizer: HuggingFace tokenizer
10
+ Tokenizer that used to compute the length of the text after encoding.
11
+ For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
12
+
13
+ Returns
14
+ -------
15
+ len_param: str
16
+ One of four buckets:
17
+ '1' for short, '2' for medium, '3' for long texts and '-' for all others.
18
+ """
19
+ tokens_count = len(tokenizer.encode(text))
20
+ if tokens_count <= 15:
21
+ len_param = '1'
22
+ elif tokens_count <= 50:
23
+ len_param = '2'
24
+ elif tokens_count <= 256:
25
+ len_param = '3'
26
+ else:
27
+ len_param = '-'
28
+ return len_param
29
+
30
+
31
+ def get_user_param(text: dict, machine_name_in_chat: str) -> str:
32
+ """Maps text by 1/0 for it to be the person or the machine in the dialogue
33
+
34
+ Parameters
35
+ ----------
36
+ text: Dict[..., 'from', ...]
37
+ Dict containing field 'from' with the name of the user who sent the message
38
+
39
+ machine_name_in_chat: str
40
+ Str with the name of the machine - it will be predicted
41
+ """
42
+ if text['from'] == machine_name_in_chat:
43
+ return '1' # machine
44
+ else:
45
+ return '0' # human
46
+
47
+
48
+ def build_text_file(data_json: dict, dest_path: str,
49
+ tokenizer, machine_name_in_chat='Кирилл Гельван'):
50
+ """Create a text file for training in special format for ruDialoGPT-3.
51
+
52
+ Parameters
53
+ ----------
54
+ data_json: dict
55
+ Dict containing 'text' (message) and 'from' (user who sent the message)
56
+
57
+ dest_path: str
58
+ String containing path to write data there
59
+
60
+ tokenizer: HuggingFace tokenizer
61
+ Tokenizer that used to compute the length of the text after encoding.
62
+ For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
63
+ """
64
+ f = open(dest_path, 'w')
65
+ new_data = ''
66
+ for i in range(len(data_json) - 1):
67
+ message, next_message = data_json[i], data_json[i+1]
68
+ if message['text'] == '' or type(message['text']) != str:
69
+ continue
70
+ if next_message['text'] == '' or type(next_message['text']) != str:
71
+ continue
72
+
73
+ user = get_user_param(message, machine_name_in_chat=machine_name_in_chat)
74
+ length = get_length_param(data_json[i+1]['text'], tokenizer)
75
+ message_text = re.sub(r"\n", ". ", message['text'])
76
+ new_data += f"|{user}|{length}|{message_text}{tokenizer.eos_token}" + "\n"
77
+
78
+ f.write(new_data)
79
+
80
+
81
+ def load_dataset(train_path, test_path, tokenizer):
82
+ """Creates train and test PyTorch datasets and collate_fn using HuggingFace.
83
+
84
+ Parameters
85
+ ----------
86
+ train_path: str
87
+ String containing path to train data
88
+
89
+ test_path: str
90
+ String containing path to test data
91
+
92
+ tokenizer: HuggingFace tokenizer
93
+ Tokenizer that used to compute the length of the text after encoding.
94
+ For more info ee https://huggingface.co/transformers/main_classes/tokenizer.html
95
+ """
96
+ train_dataset = TextDataset(
97
+ tokenizer = tokenizer,
98
+ file_path = train_path,
99
+ block_size = 256)
100
+
101
+ test_dataset = TextDataset(
102
+ tokenizer = tokenizer,
103
+ file_path = test_path,
104
+ block_size = 256)
105
+
106
+ data_collator = DataCollatorForLanguageModeling(
107
+ tokenizer=tokenizer, mlm=False
108
+ )
109
+ return train_dataset, test_dataset, data_collator