Turkunov Y committed
Commit b746c52
1 Parent(s): 4e2c4f4

Update app.py

Files changed (1)
  1. app.py +9 -4
app.py CHANGED
@@ -3,15 +3,20 @@ from textPreprocessing import text2prompt
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
 
+"""
+Uncomment when GPU access is available
+
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
     bnb_4bit_quant_type="fp4",
     bnb_4bit_compute_dtype=torch.bfloat16
-)
+)"""
 
-model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",
-                                             quantization_config=bnb_config)
+model = AutoModelForCausalLM.from_pretrained(
+    "mistralai/Mistral-7B-Instruct-v0.1",
+    # quantization_config=bnb_config  # uncomment when GPU access is available
+)
 tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
 
 def predict(input_text, t, m):
@@ -25,7 +30,7 @@ def predict(input_text, t, m):
     - Instruct-based model
     """
     prompt = text2prompt(input_text)
-    inputs = tokenizer(prompt, return_tensors="pt")
+    inputs = tokenizer(prompt, return_tensors="np")
     generate_ids = model.generate(inputs.input_ids, max_new_tokens=128)
     answer = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     return answer.replace(prompt, "")
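
The commit lets the Space start on CPU-only hardware by commenting out the 4-bit bitsandbytes config (bitsandbytes quantization requires CUDA) and leaves a note to re-enable it once a GPU is available. Below is a minimal sketch of how the same switch could be made automatic with torch.cuda.is_available(); the conditional, the MODEL_ID constant, and the return_tensors="pt" choice are assumptions of this sketch and not part of the commit (generate() on a PyTorch model expects torch tensors as input).

# Sketch only: automatic GPU/CPU toggle for app.py, assuming the same model and
# the text2prompt helper shipped in this repo; the conditional is not part of the commit.
import torch
from textPreprocessing import text2prompt
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"  # hypothetical constant for readability

if torch.cuda.is_available():
    # 4-bit quantization via bitsandbytes requires a CUDA GPU.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="fp4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, quantization_config=bnb_config)
else:
    # CPU fallback: load unquantized weights, as the committed code does.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

def predict(input_text, t, m):
    prompt = text2prompt(input_text)
    # generate() on a PyTorch model expects torch tensors, hence return_tensors="pt".
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_ids = model.generate(inputs.input_ids, max_new_tokens=128)
    answer = tokenizer.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return answer.replace(prompt, "")

With a toggle like this, the manual comment/uncomment step noted in the committed file would not be needed when moving between CPU and GPU hardware.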