yurezsml committed on
Commit
056b8c6
1 Parent(s): 3463fd4

Create app.py

Files changed (1)
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
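+ # Gradio demo: a chatbot that answers in the style of Chandler Bing from
+ # the TV series Friends. Replies are generated by a 4-bit quantized
+ # microsoft/phi-2 with a LoRA adapter; an optional RAG step retrieves
+ # passages from the Chandler Bing Wikipedia page via a FAISS index over
+ # mean-pooled BERT sentence embeddings.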
+ import asyncio
+ import re
+ import unicodedata
+ from typing import List
+
+ import bs4
+ import faiss
+ import gradio as gr
+ import nltk
+ import numpy as np
+ import requests
+ import torch
+ from nltk import sent_tokenize
+ from peft import PeftModel
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ base_model_id = "microsoft/phi-2"
+
+ # Load the base model in 4-bit NF4 so it fits in limited GPU memory.
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type='nf4',
+     bnb_4bit_compute_dtype='float16',
+     bnb_4bit_use_double_quant=True,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     base_model_id,
+     device_map='auto',
+     quantization_config=bnb_config,
+     trust_remote_code=True,
+ )
+ # Tokenizer for the generator model (used below in get_answer).
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id)
+
+ # Attach the LoRA adapter finetuned on Chandler's lines.
+ ft_model = PeftModel.from_pretrained(model, "yurezsml/phi2_chan")
+
+ def remove_accents(input_str):
+     # Decompose characters and drop combining marks (e.g. accents).
+     nfkd_form = unicodedata.normalize('NFKD', input_str)
+     return "".join(c for c in nfkd_form if not unicodedata.combining(c))
+
+ def preprocess(text):
+     text = text.lower()
+     text = remove_accents(text)
+     text = text.replace('\xa0', ' ')
+     text = text.replace('\n\n', '\n')
+     text = text.replace('()', '')
+     text = text.replace('[]', '')
+     # Strip bracketed asides such as citation markers: "(...)" and "[...]".
+     text = re.sub(r"[\(\[].*?[\)\]]", "", text)
+     return text
+
+ def split_text(text: str, n: int = 2) -> List[str]:
+     # Split the text into passages of n consecutive sentences each.
+     text = preprocess(text)
+     all_sentences = sent_tokenize(text)
+     return [' '.join(all_sentences[i : i + n]) for i in range(0, len(all_sentences), n)]
+
+
+ def split_documents(documents: List[str]) -> List[str]:
+     texts = []
+     for text in documents:
+         if text is not None:
+             for passage in split_text(text):
+                 texts.append(passage)
+     return texts
+
+
+ def embed(text, model, tokenizer):
+     encoded_input = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt').to(model.device)
+     with torch.no_grad():
+         model_output = model(**encoded_input)
+     token_embeddings = model_output[0]  # first element holds all token embeddings
+     # Mean pooling: average token embeddings, ignoring padding positions.
+     input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
+     sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
+     sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+     return sum_embeddings / sum_mask
+
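+ # Shape sketch of the pooling above: for a batch of B texts the encoder
+ # returns token embeddings of shape (B, seq_len, 768); masking out padding
+ # and averaging over seq_len yields one 768-dim vector per text, e.g.
+ # embed(["hi"], fact_coh_model, fact_coh_tokenizer) has shape (1, 768).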
+ # Fetch the knowledge source for RAG: the Wikipedia article on Chandler Bing.
+ response = requests.get("https://en.wikipedia.org/wiki/Chandler_Bing")
+
+ base_text = ''
+
+ if response:
+     html = bs4.BeautifulSoup(response.text, 'html.parser')
+     paragraphs = html.select("p")
+     for para in paragraphs:
+         base_text = base_text + para.text
+
+ # Sentence-embedding model used only for retrieval, not for generation.
+ fact_coh_tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
+ fact_coh_model = AutoModel.from_pretrained("DeepPavlov/bert-base-multilingual-cased-sentence")
+ fact_coh_model.to(device)
+
+ nltk.download('punkt')  # sentence tokenizer data (newer NLTK versions may also need 'punkt_tab')
+ subsample_documents = split_documents([base_text])
+
+ # Embed the passages in batches and collect one vector per passage.
+ batch_size = 8
+ total_batches = len(subsample_documents) // batch_size + (0 if len(subsample_documents) % batch_size == 0 else 1)
+
+ base = list()
+ for i in tqdm(range(0, len(subsample_documents), batch_size), total=total_batches, desc="Processing Batches"):
+     batch_texts = subsample_documents[i:i + batch_size]
+     base.extend(embed(batch_texts, fact_coh_model, fact_coh_tokenizer))
+
+ base = np.array([vector.cpu().numpy() for vector in base])
+
+ # Build an L2 nearest-neighbour index over the passage vectors.
+ index = faiss.IndexFlatL2(base.shape[1])
+ index.add(base)
+
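+ # IndexFlatL2 performs exact brute-force L2 search over float32 vectors,
+ # which is cheap at this scale (one Wikipedia article's worth of passages).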
+ async def get_context(subsample_documents, query, index, model, tokenizer):
+     # Retrieve the k nearest passages, then keep only the single best match.
+     k = 5
+     xq = embed(query.lower(), model, tokenizer).cpu().numpy()
+     D, I = index.search(xq.reshape(1, -1), k)
+     return subsample_documents[I[0][0]]
+
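+ # Illustrative call (hypothetical query): awaiting
+ # get_context(subsample_documents, "Where does Chandler work?", index,
+ # fact_coh_model, fact_coh_tokenizer) returns the stored passage whose
+ # embedding is nearest to the query's in L2 distance.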
+ async def get_prompt(question, use_rag, answers_history: list[str]):
+     # Rebuild the dialogue as alternating ###question/###answer turns.
+     eval_prompt = '###system: answer the question as Chandler. '
+     for idx, text in enumerate(answers_history):
+         if idx % 2 == 0:
+             eval_prompt = eval_prompt + f' ###question: {text}'
+         else:
+             eval_prompt = eval_prompt + f' ###answer: {text} '
+     if use_rag:
+         # Add a retrieved Wikipedia passage to ground the next answer.
+         context = await asyncio.wait_for(get_context(subsample_documents, question, index, fact_coh_model, fact_coh_tokenizer), timeout=60)
+         eval_prompt = eval_prompt + f' Chandler. {context}'
+     eval_prompt = eval_prompt + f' ###question: {question} '
+     eval_prompt = ' '.join(eval_prompt.split())  # collapse whitespace
+     return eval_prompt
+
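+ # The assembled prompt looks roughly like this (illustrative, not verbatim):
+ # ###system: answer the question as Chandler. ###question: <past turn>
+ # ###answer: <past reply> Chandler. <retrieved passage> ###question: <new question>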
+ async def get_answer(question, use_rag, answers_history: list[str]):
+     eval_prompt = await asyncio.wait_for(get_prompt(question, use_rag, answers_history), timeout=60)
+     model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
+     ft_model.eval()
+     with torch.no_grad():
+         answer = tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=30, repetition_penalty=1.11)[0], skip_special_tokens=True)
+     answer = ' '.join(answer.split())
+     # generate() echoes the prompt; strip it, then keep the text after the
+     # first ###answer marker (fall back to the whole string if it is absent).
+     if eval_prompt in answer:
+         answer = answer.replace(eval_prompt, '')
+     parts = answer.split('###answer')
+     answer = parts[1] if len(parts) > 1 else parts[0]
+
+     # Render the running dialogue for display.
+     dialog = ''
+     for idx, text in enumerate(answers_history):
+         if idx % 2 == 0:
+             dialog = dialog + f'you: {text}\n'
+         else:
+             dialog = dialog + f'Chandler: {text}\n'
+     dialog = dialog + f'you: {question}\n'
+     dialog = dialog + f'Chandler: {answer}\n'
+
+     answers_history.append(question)
+     answers_history.append(answer)
+
+     return dialog, answers_history
+
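+ # answers_history holds the raw turn list (even indices: user, odd: bot),
+ # so each request rebuilds the full conversation for the prompt and display.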
+ async def async_proc(question, use_rag, answers_history: list[str]):
+     try:
+         return await asyncio.wait_for(get_answer(question, use_rag, answers_history), timeout=60)
+     except asyncio.TimeoutError:
+         return "Processing timed out.", answers_history
+
+ gr.Interface(
+     fn=async_proc,
+     inputs=[
+         gr.Textbox(
+             label="Question",
+         ),
+         gr.Checkbox(label="Use RAG", info="Enable RAG to improve factual coherence"),
+         gr.State(value=[]),
+     ],
+     outputs=[
+         gr.Textbox(
+             label="Chat"
+         ),
+         gr.State(),
+     ],
+     title="Asynchronous chatbot service for the TV series Friends",
+     concurrency_limit=5
+ ).queue().launch(share=True, debug=True)
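+ # queue() routes requests through Gradio's request queue; concurrency_limit=5
+ # caps how many calls run at once, and share=True requests a temporary public URL.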