|
"""Streaming chat simulator built on Qwen/Qwen2-0.5B-Instruct."""
|
|
|
from threading import Thread |
|
from models.base_model import Simulator |
|
|
|
from transformers import TextIteratorStreamer |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
class Qwen2Simulator(Simulator):
    """Chat simulator that can stream either side of a Qwen2 conversation.

    ``generate_response`` produces the assistant's next turn;
    ``generate_query`` produces the *user's* next turn by appending a
    ``<|im_start|>user`` header and letting the model continue.
    """

    def __init__(self, model_name_or_path):
        """Load tokenizer and model from *model_name_or_path*.

        Note: when ``device_map`` is passed, ``low_cpu_mem_usage`` is
        automatically set to True by transformers.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",
            device_map="auto",
        )
        self.model.eval()
        # Shared decoding settings. When both are given, max_new_tokens
        # takes precedence over max_length.
        self.generation_kwargs = dict(
            do_sample=True,
            temperature=0.7,
            max_length=500,
            max_new_tokens=20,
        )

    def _stream(self, input_ids):
        """Run generation on a background thread and yield decoded text pieces.

        Shared by both public generators so the streamer is configured
        consistently (skip the prompt, drop special tokens).
        """
        streamer = TextIteratorStreamer(
            tokenizer=self.tokenizer,
            skip_prompt=True,
            timeout=120.0,
            skip_special_tokens=True,
        )
        # BUG FIX: the original used dict(...).update(self.generation_kwargs),
        # which returns None (update() mutates in place), so the generate
        # thread received kwargs=None. Merge the dicts instead.
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            **self.generation_kwargs,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()
        yield from streamer

    def generate_query(self, messages, stream=True):
        """Yield text pieces of a model-generated *user* turn.

        :param messages: chat history; the last message must NOT be a user turn.
        :param stream: kept for interface compatibility (output is always streamed).
        """
        assert messages[-1]["role"] != "user"
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Open a user turn manually so the model continues as the user.
        inputs = inputs + "<|im_start|>user\n"
        input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device)
        yield from self._stream(input_ids)

    def generate_response(self, messages, stream=True):
        """Yield text pieces of a model-generated *assistant* turn.

        :param messages: chat history; the last message MUST be a user turn.
        :param stream: kept for interface compatibility (output is always streamed).
        """
        assert messages[-1]["role"] == "user"
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(self.model.device)
        yield from self._stream(input_ids)

    def _generate(self, input_ids):
        """Non-streaming generation: return only the newly generated text."""
        input_ids_length = input_ids.shape[-1]
        response = self.model.generate(input_ids=input_ids, **self.generation_kwargs)
        # Slice off the prompt tokens so only the completion is decoded.
        return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True)
|
|
|
|
|
# Module-level simulator instance (note: loads the model at import time).
bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct")


if __name__ == "__main__":
    demo_history = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "hi, what your name"},
    ]
    # Drain the streaming generator and show all pieces at once.
    reply_pieces = bot.generate_response(demo_history)
    print(list(reply_pieces))
|
|