xu song committed
Commit 9dc0e21
1 Parent(s): 2c8aed7
Files changed (2):
  1. models/cpp_qwen2.py +95 -0
  2. simulator.py +6 -22
models/cpp_qwen2.py ADDED
@@ -0,0 +1,95 @@
+ """
+ https://github.com/abetlen/llama-cpp-python/blob/main/examples/gradio_chat/local.py
+ https://github.com/awinml/llama-cpp-python-bindings
+ """
+
+ from simulator import Simulator
+ from llama_cpp import Llama
+ import llama_cpp.llama_tokenizer
+
+
+ class Qwen2Simulator(Simulator):
+
+     def __init__(self, model_name_or_path=None):
+         # self.llm = llama_cpp.Llama.from_pretrained(
+         #     repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
+         #     filename="*q8_0.gguf",
+         #     tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+         #         "Qwen/Qwen1.5-0.5B-Chat"
+         #     ),
+         #     verbose=False,
+         # )
+
+         # Keep a handle on the HF tokenizer: generate_query/generate_response
+         # call self.tokenizer.apply_chat_template, which lives on the wrapped
+         # transformers tokenizer, not on the Llama object.
+         llama_tokenizer = llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+             "/workspace/czy/model_weights/Qwen1.5-0.5B-Chat/"
+         )
+         self.tokenizer = llama_tokenizer.hf_tokenizer
+         self.llm = Llama(
+             model_path="Qwen/Qwen1.5-0.5B-Chat-GGUF/qwen1_5-0_5b-chat-q8_0.gguf",
+             # n_gpu_layers=-1,  # Uncomment to use GPU acceleration
+             # seed=1337,        # Uncomment to set a specific seed
+             # n_ctx=2048,       # Uncomment to increase the context window
+             tokenizer=llama_tokenizer,
+             verbose=False,
+         )
+
+     def generate_query(self, messages):
+         """ generate the next user query given the dialog so far
+         :param messages: list of {"role": ..., "content": ...} dicts
+         :return: generated query text
+         """
+         assert messages[-1]["role"] != "user"
+         inputs = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=False,
+         )
+         # Open a user turn so the model completes it as the user simulator.
+         inputs = inputs + "<|im_start|>user\n"
+         return self._generate(inputs)
+         # for new_text in self._stream_generate(input_ids):
+         #     yield new_text
+
+     def generate_response(self, messages):
+         """ generate the assistant response to the last user message
+         :param messages: list of {"role": ..., "content": ...} dicts
+         :return: generated response text
+         """
+         assert messages[-1]["role"] == "user"
+         inputs = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+         return self._generate(inputs)
+         # for new_text in self._stream_generate(input_ids):
+         #     yield new_text
+
+     def _generate(self, inputs):
+         # Non-streaming completion; stop at the chat end-of-turn marker.
+         output = self.llm(
+             inputs,
+             max_tokens=20,
+             temperature=0.7,
+             stop=["<|im_end|>"]
+         )
+         output_text = output["choices"][0]["text"]
+         return output_text
+
+
+ # Note: model_name_or_path is currently unused; the paths above are hard-coded.
+ bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")
+
+
+ if __name__ == "__main__":
+     messages = [
+         {"role": "system", "content": "you are a helpful assistant"},
+         {"role": "user", "content": "What is the capital of France?"}
+     ]
+     output = bot.generate_response(messages)
+     print(output)
+
+     messages = [
+         {"role": "system", "content": "you are a helpful assistant"},
+         {"role": "user", "content": "hi, what's your name"},
+         {"role": "assistant", "content": "My name is Jordan"}
+     ]
+     output = bot.generate_query(messages)
+     print(output)
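
The commented-out loops above reference a _stream_generate helper that is not defined in this commit. A minimal sketch of what such a method on Qwen2Simulator could look like, assuming llama-cpp-python's stream=True mode (the method name comes from the comments; the body is an assumption, not part of the commit):

    # Hypothetical streaming counterpart to _generate (not in this commit).
    # With stream=True, llama-cpp-python yields partial completion chunks.
    def _stream_generate(self, inputs):
        for chunk in self.llm(
            inputs,
            max_tokens=20,
            temperature=0.7,
            stop=["<|im_end|>"],
            stream=True,
        ):
            yield chunk["choices"][0]["text"]

With this in place, the commented for-loops in generate_query/generate_response would yield incremental text instead of returning a single string.
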
simulator.py CHANGED
@@ -1,38 +1,22 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer
+
 
 
  class Simulator:
 
      def __init__(self, model_name_or_path):
-         """
-         When device_map is passed, low_cpu_mem_usage is automatically set to True.
-         """
+         raise NotImplementedError
 
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_name_or_path,
-             torch_dtype="auto",
-             device_map="auto"
-         )
-         self.model.eval()
-         self.generation_kwargs = dict(
-             do_sample=True,
-             temperature=0.7,
-             max_length=500,
-             max_new_tokens=10
-         )
 
-     def generate_query(self, history):
+     def generate_query(self, messages):
          """ user simulator
-         :param history:
+         :param messages:
          :return:
          """
          raise NotImplementedError
 
-     def generate_response(self, input, history):
+     def generate_response(self, messages):
          """ assistant simulator
-         :param input:
-         :param history:
+         :param messages:
          :return:
          """
          raise NotImplementedError
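
Taken together, the two methods of the Simulator interface support a self-chat loop: generate_query plays the user and generate_response plays the assistant. A hypothetical driver (not part of this commit; the import path assumes the file layout above):

    # Roll out a synthetic conversation by alternating the two simulators.
    from models.cpp_qwen2 import bot

    messages = [{"role": "system", "content": "you are a helpful assistant"}]
    for _ in range(3):  # three user/assistant exchanges
        query = bot.generate_query(messages)
        messages.append({"role": "user", "content": query})
        response = bot.generate_response(messages)
        messages.append({"role": "assistant", "content": response})

    for m in messages:
        print(m["role"], ":", m["content"])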