Spaces:
Runtime error
Runtime error
Turkunov Y
committed on
Commit
•
b746c52
1
Parent(s):
4e2c4f4
Update app.py
Browse files
app.py
CHANGED
@@ -3,15 +3,20 @@ from textPreprocessing import text2prompt
|
|
3 |
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
4 |
import torch
|
5 |
|
|
|
|
|
|
|
6 |
bnb_config = BitsAndBytesConfig(
|
7 |
load_in_4bit=True,
|
8 |
bnb_4bit_use_double_quant=True,
|
9 |
bnb_4bit_quant_type="fp4",
|
10 |
bnb_4bit_compute_dtype=torch.bfloat16
|
11 |
-
)
|
12 |
|
13 |
-
model = AutoModelForCausalLM.from_pretrained(
|
14 |
-
|
|
|
|
|
15 |
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
|
16 |
|
17 |
def predict(input_text, t, m):
|
@@ -25,7 +30,7 @@ def predict(input_text, t, m):
|
|
25 |
- Instruct-based модель
|
26 |
"""
|
27 |
prompt = text2prompt(input_text)
|
28 |
-
inputs = tokenizer(prompt, return_tensors="pt")
|
29 |
generate_ids = model.generate(inputs.input_ids, max_new_tokens=128)
|
30 |
answer = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
31 |
return answer.replace(prompt, "")
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# NOTE: 4-bit quantization config, disabled for now.
# Uncomment (here and in from_pretrained below) when GPU access is available.
# Kept as real comments rather than a bare triple-quoted string: a string
# expression is still built and discarded at import time and can confuse tooling.
#
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="fp4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

# Full-precision load (CPU-compatible); model weights are fetched from the Hub.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    # quantization_config=bnb_config,  # uncomment when GPU access is available
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
|
21 |
|
22 |
def predict(input_text, t, m):
    """Generate a model answer for *input_text* using the Instruct-based model.

    The input is converted to a chat prompt via ``text2prompt``, fed to the
    module-level ``model``/``tokenizer``, and the prompt text is stripped from
    the decoded output so only the generated answer is returned.

    NOTE(review): parameters ``t`` and ``m`` are not used in the visible body —
    presumably temperature / max-tokens placeholders for the UI; confirm against
    the caller before removing.
    """
    prompt = text2prompt(input_text)
    # BUG FIX: return_tensors must be "pt" (PyTorch), not "np" —
    # model.generate cannot consume NumPy arrays; with "np" this call
    # raises at runtime (matching the Space's "Runtime error" status).
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_ids = model.generate(inputs.input_ids, max_new_tokens=128)
    answer = tokenizer.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return answer.replace(prompt, "")
|