gurgutan committed on
Commit: b6d6ec2
Parent: 2f0aa8b

Update README.md

Files changed (1)
  1. README.md +37 -73
README.md CHANGED
@@ -16,81 +16,45 @@ GITHUB_ACTIONS=true pip install auto-gptq
  Example code for using the model to generate a response:

  ```python
- from transformers import AutoTokenizer, TextGenerationPipeline
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
-
-
- class Conversation:
-     def __init__(
-         self,
-         message_template=DEFAULT_MESSAGE_TEMPLATE,
-         system_prompt=DEFAULT_SYSTEM_PROMPT,
-         start_token_id=1,
-         bot_token_id=9225
-     ):
-         self.message_template = message_template
-         self.start_token_id = start_token_id
-         self.bot_token_id = bot_token_id
-         self.messages = [{
-             "role": "system",
-             "content": system_prompt
-         }]
-
-     def get_start_token_id(self):
-         return self.start_token_id
-
-     def get_bot_token_id(self):
-         return self.bot_token_id
-
-     def add_user_message(self, message):
-         self.messages.append({
-             "role": "user",
-             "content": message
-         })
-
-     def add_bot_message(self, message):
-         self.messages.append({
-             "role": "bot",
-             "content": message
-         })
-
-     def get_prompt(self, tokenizer):
-         final_text = ""
-         for message in self.messages:
-             message_text = self.message_template.format(**message)
-             final_text += message_text
-         final_text += tokenizer.decode([self.start_token_id, self.bot_token_id])
-         return final_text.strip()
-
-
- def generate(model, tokenizer, prompt, generation_config):
-     data = tokenizer(prompt, return_tensors="pt")
-     data = {k: v.to(model.device) for k, v in data.items()}
-     output_ids = model.generate(
-         **data,
-         generation_config=generation_config
-     )[0]
-     output_ids = output_ids[len(data["input_ids"][0]):]
-     output = tokenizer.decode(output_ids, skip_special_tokens=True)
-     return output.strip()
-
-
- MODEL_NAME = "gurgutan/saiga2-13b-4bit"
- DEFAULT_MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"
- DEFAULT_SYSTEM_PROMPT = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
- model = AutoGPTQForCausalLM.from_quantized(MODEL_NAME, device="cuda:0", use_safetensors=True, use_triton=False)
- generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
  model.eval()

- input = "Сочини стих, который начинается словами: Буря мглою небо кроет"
- conversation = Conversation()
- conversation.add_user_message(input)
- prompt = conversation.get_prompt(tokenizer)
- output = generate(model, tokenizer, prompt, generation_config)
- print(inp)
- print(output)

  ```
  # Original model: [saiga2-13B-lora](https://huggingface.co/IlyaGusev/saiga2_13b_lora)
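
For reference, the removed `Conversation.get_prompt` renders each message through the `<s>{role}\n{content}</s>\n` template and then opens a bot turn by decoding `[start_token_id, bot_token_id]`. A minimal sketch of the prompt it produces for one user message (the helper name is illustrative, not from the README; assumes ids 1 and 9225 decode to `<s>` and `bot`):

```python
MESSAGE_TEMPLATE = "<s>{role}\n{content}</s>\n"

def render_single_turn(system_prompt: str, user_message: str) -> str:
    # Mirrors Conversation.get_prompt: system message, user message, open bot turn
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    prompt = "".join(MESSAGE_TEMPLATE.format(**m) for m in messages)
    return (prompt + "<s>bot").strip()

# Resulting layout:
# <s>system
# ...system prompt...</s>
# <s>user
# ...user message...</s>
# <s>bot
```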
 
  Example code for using the model to generate a response:

  ```python
+ from transformers import AutoTokenizer
+ from auto_gptq import AutoGPTQForCausalLM
+
+ device = "cuda:0"
+
+ def generate_answer(model, tokenizer, request: str):
+     # Saiga prompt format: <s>{role}\n{content}</s> blocks, ending with an open bot turn;
+     # system_prompt is defined at module level below
+     s = f"<s>system\n{system_prompt}</s>\n" + \
+         f"<s>user\n{request}</s>\n" + \
+         "<s>bot\n"
+     request_tokens = tokenizer(s, return_tensors="pt")
+     request_tokens.pop("token_type_ids", None)  # not used by Llama-family models
+     request_tokens.pop("attention_mask", None)
+     request_tokens = request_tokens.to(model.device)
+     answer_tokens = model.generate(**request_tokens,
+                                    num_beams=4,
+                                    top_k=32,
+                                    temperature=0.6,
+                                    repetition_penalty=1.2,
+                                    no_repeat_ngram_size=15,
+                                    max_new_tokens=1536,
+                                    pad_token_id=tokenizer.eos_token_id)[0]
+     print(request)
+     # Keep only the newly generated tokens, dropping the echoed prompt
+     answer_tokens = answer_tokens[len(request_tokens["input_ids"][0]):]
+     answer = tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
+     print(answer)
+     return answer
+
+ model_name = "gurgutan/saiga2-13b-4bit"
+ system_prompt = "Ты — Сайга, русскоязычный автоматический ассистент. Ты разговариваешь с людьми и помогаешь им."
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+ model = AutoGPTQForCausalLM.from_quantized(model_name, device=device)
  model.eval()

+ user_text = "Сочини стих, который начинается словами: Буря мглою небо кроет"
+ answer_text = generate_answer(model, tokenizer, user_text)
+ print(answer_text)

  ```
  # Original model: [saiga2-13B-lora](https://huggingface.co/IlyaGusev/saiga2_13b_lora)
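
The updated snippet handles a single user turn. If conversation history is needed, the same prompt layout extends naturally; a minimal sketch under that assumption (the helper is illustrative, not part of the README):

```python
def build_chat_prompt(system_prompt: str, history, request: str) -> str:
    # history: list of (user_text, bot_text) pairs from earlier completed turns
    s = f"<s>system\n{system_prompt}</s>\n"
    for user_text, bot_text in history:
        s += f"<s>user\n{user_text}</s>\n<s>bot\n{bot_text}</s>\n"
    # Open a fresh bot turn for the new request
    return s + f"<s>user\n{request}</s>\n<s>bot\n"
```

The returned string can then be tokenized and passed through the same `model.generate` call used in `generate_answer` above.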