Update README.md
README.md CHANGED
@@ -16,81 +16,45 @@ GITHUB_ACTIONS=true pip install auto-gptq
 Example code for using the model to generate a response:
 
 ```python
-from transformers import AutoTokenizer
-from auto_gptq import AutoGPTQForCausalLM
-
-class Conversation:
-    ...
-
-    def add_bot_message(self, message):
-        self.messages.append({
-            "role": "bot",
-            "content": message
-        })
-
-    def get_prompt(self, tokenizer):
-        final_text = ""
-        for message in self.messages:
-            message_text = self.message_template.format(**message)
-            final_text += message_text
-        final_text += tokenizer.decode([self.start_token_id, self.bot_token_id])
-        return final_text.strip()
-
-
-def generate(model, tokenizer, prompt, generation_config):
-    data = tokenizer(prompt, return_tensors="pt")
-    data = {k: v.to(model.device) for k, v in data.items()}
-    output_ids = model.generate(
-        **data,
-        generation_config=generation_config
-    )[0]
-    output_ids = output_ids[len(data["input_ids"][0]):]
-    output = tokenizer.decode(output_ids, skip_special_tokens=True)
-    return output.strip()
-
-
-MODEL_NAME = "gurgutan/saiga2-13b-4bit"
-DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
-DEFAULT_SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
-model = AutoGPTQForCausalLM.from_quantized(MODEL_NAME, device="cuda:0", use_safetensors=True, use_triton=False)
-generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
+from transformers import AutoTokenizer
+from auto_gptq import AutoGPTQForCausalLM
+
+device = "cuda:0"
+
+def generate_answer(model, tokenizer, request: str):
+    # Build a single-turn prompt in the Saiga format: <s>{role}\n{content}</s>\n
+    s = f"<s>system\n{system_prompt}</s>\n" + \
+        f"<s>user\n{request}</s>\n" + \
+        f"<s>bot\n"
+    request_tokens = tokenizer(s, return_tensors="pt")
+    # Keep only input_ids; generate() builds a default attention mask itself
+    request_tokens.pop("token_type_ids", None)
+    request_tokens.pop("attention_mask", None)
+    request_tokens = request_tokens.to(model.device)
+    answer_tokens = model.generate(**request_tokens,
+                                   num_beams=4,
+                                   top_k=32,
+                                   temperature=0.6,
+                                   repetition_penalty=1.2,
+                                   no_repeat_ngram_size=15,
+                                   max_new_tokens=1536,
+                                   pad_token_id=tokenizer.eos_token_id)[0]
+    # Cut off the prompt tokens and the trailing end-of-sequence token
+    answer_tokens = answer_tokens[len(request_tokens["input_ids"][0]):-1]
+    answer = tokenizer.decode(answer_tokens).strip()
+    return answer
+
+model_name = "gurgutan/saiga2-13b-4bit"
+system_prompt = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
+tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+model = AutoGPTQForCausalLM.from_quantized(model_name, device=device)
 model.eval()
 
-inp = ...
-conversation = Conversation()
-conversation.add_user_message(inp)
-prompt = conversation.get_prompt(tokenizer)
-output = generate(model, tokenizer, prompt, generation_config)
-print(inp)
-print(output)
+user_text = "Сочини стих, который начинается словами: Буря мглою небо кроет"
+answer_text = generate_answer(model, tokenizer, user_text)
+print(answer_text)
 
 ```
 # Original model: [saiga2-13B-lora](https://huggingface.co/IlyaGusev/saiga2_13b_lora)
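The new snippet is single-turn: each call to `generate_answer` builds a prompt from one user message, whereas the removed `Conversation` class kept a running message history. Below is a minimal sketch of how multi-turn chat could be layered back on top. It assumes the `model`, `tokenizer`, and `system_prompt` defined in the snippet above, and the `chat` helper is an illustrative name introduced here, not part of this commit.

```python
# Editorial sketch, not part of this commit: multi-turn chat on top of the
# same Saiga template "<s>{role}\n{content}</s>\n" used by generate_answer().
# Assumes `model`, `tokenizer`, and `system_prompt` from the README snippet.
def chat(model, tokenizer, history, user_text):
    history.append(("user", user_text))
    prompt = f"<s>system\n{system_prompt}</s>\n"
    for role, content in history:          # role is "user" or "bot"
        prompt += f"<s>{role}\n{content}</s>\n"
    prompt += "<s>bot\n"                   # cue the model to reply
    tokens = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        input_ids=tokens["input_ids"],
        num_beams=4,
        repetition_penalty=1.2,
        max_new_tokens=1536,
        pad_token_id=tokenizer.eos_token_id,
    )[0]
    # Keep only the newly generated tokens; skip_special_tokens drops </s>
    answer = tokenizer.decode(
        output_ids[len(tokens["input_ids"][0]):],
        skip_special_tokens=True,
    ).strip()
    history.append(("bot", answer))
    return answer

history = []
print(chat(model, tokenizer, history, "Привет! Кто ты?"))
print(chat(model, tokenizer, history, "Сочини стих про осень"))
```

Rebuilding the full prompt on every turn keeps the helper stateless apart from `history`; the cost is re-encoding the whole conversation each call, which is negligible next to generation time.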