Turkunov Y commited on
Commit
a979393
1 Parent(s): eb009b7

Code for inference

Browse files
Files changed (3) hide show
  1. app.py +44 -0
  2. requirements.txt +2 -0
  3. textPreprocessing.py +21 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from textPreprocessing import text2prompt
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
+ import torch
5
+
6
+ bnb_config = BitsAndBytesConfig(
7
+ load_in_4bit=True,
8
+ bnb_4bit_use_double_quant=True,
9
+ bnb_4bit_quant_type="fp4",
10
+ bnb_4bit_compute_dtype=torch.bfloat16
11
+ )
12
+
13
+ model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",
14
+ quantization_config=bnb_config)
15
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
16
+
17
+ def predict(input_text, t, m):
18
+ """
19
+ Вывести финансовую рекомендацию на основе:
20
+ input_text: str
21
+ - Контекст в виде новости из области экономики
22
+ t: tokenizer
23
+ - Токенизатор для модели
24
+ m: model
25
+ - Instruct-based модель
26
+ """
27
+ prompt = text2prompt(input_text)
28
+ inputs = tokenizer(prompt, return_tensors="pt")
29
+ generate_ids = model.generate(inputs.input_ids, max_new_tokens=128)
30
+ answer = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
31
+ return answer.replace(prompt, "")
32
+
33
+ gradio_app = gr.Interface(
34
+ predict,
35
+ inputs=gr.Textbox(
36
+ label="Входная новость", sources=['upload', 'webcam'], container=True,
37
+ lines=8, placeholder="Акции кредитного банка \"X\" обрушились в цене из-за дефолта по ипотечным кредитам"
38
+ ),
39
+ outputs=[gr.Label(label="Финансовая рекомендация на основе новости:")],
40
+ title="Finam Finetuned Mistral Instruct (FFMI)",
41
+ )
42
+
43
+ if __name__ == "__main__":
44
+ gradio_app.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch
textPreprocessing.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def text2prompt(text: str):
2
+ """
3
+ Применяет инструкцию в формате, который распознается
4
+ Mistral Instruct и выводит на выходе входной текст для модели
5
+ """
6
+
7
+ instruction = 'Ниже тебе дан текст, содержащий новость на тему экономики. ' \
8
+ 'На основе текста ты должен дать наиболее подходящую финансовую рекомендацию из списка ["buy", "sell", "long", "short"]. ' \
9
+ 'Твой ответ должен содержать только одно слово, которое будет подходящей рекомендацией из списка.'
10
+
11
+ bos_token = "<s>"
12
+ eos_token = "</s>"
13
+
14
+ full_prompt = ""
15
+ full_prompt += bos_token
16
+ full_prompt += "### Instruction:"
17
+ full_prompt += "\n" + instruction
18
+ full_prompt += "\n\n### Input:"
19
+ full_prompt += "\n" + text
20
+ full_prompt += "\n\n### Response:"
21
+ full_prompt += eos_token