"Qwen/Qwen2-0.5B-Instruct" from threading import Thread from models.base_model import Simulator from transformers import TextIteratorStreamer from transformers import AutoModelForCausalLM, AutoTokenizer class Qwen2Simulator(Simulator): def __init__(self, model_name_or_path): """ 在传递 device_map 时,low_cpu_mem_usage 会自动设置为 True """ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.model = AutoModelForCausalLM.from_pretrained( model_name_or_path, torch_dtype="auto", device_map="auto" ) self.model.eval() self.generation_kwargs = dict( do_sample=True, temperature=0.7, # repetition_penalty= max_length=500, max_new_tokens=20 ) def generate_query(self, messages, stream=True): """ :param messages: :return: """ assert messages[-1]["role"] != "user" inputs = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=False, ) inputs = inputs + "<|im_start|>user\n" input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.model.device) streamer = TextIteratorStreamer(tokenizer=self.tokenizer, skip_prompt=True, timeout=120.0, skip_special_tokens=True) stream_generation_kwargs = dict( input_ids=input_ids, streamer=streamer ).update(self.generation_kwargs) thread = Thread(target=self.model.generate, kwargs=stream_generation_kwargs) thread.start() for new_text in streamer: print(new_text) yield new_text # return self._generate(input_ids) def generate_response(self, messages, stream=True): assert messages[-1]["role"] == "user" input_ids = self.tokenizer.apply_chat_template( messages, tokenize=True, return_tensors="pt", add_generation_prompt=True ).to(self.model.device) streamer = TextIteratorStreamer( tokenizer=self.tokenizer, # skip_prompt=True, # timeout=120.0, # skip_special_tokens=True ) generation_kwargs = dict( input_ids=input_ids, streamer=streamer ).update(self.generation_kwargs) print(generation_kwargs) thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() for new_text in streamer: print(new_text) yield new_text def _generate(self, input_ids): input_ids_length = input_ids.shape[-1] response = self.model.generate(input_ids=input_ids, **self.generation_kwargs) return self.tokenizer.decode(response[0][input_ids_length:], skip_special_tokens=True) bot = Qwen2Simulator(r"E:\data_model\Qwen2-0.5B-Instruct") # bot = Qwen2Simulator("Qwen/Qwen2-0.5B-Instruct") if __name__ == "__main__": messages = [ {"role": "system", "content": "you are a helpful assistant"}, {"role": "user", "content": "hi, what your name"} ] streamer = bot.generate_response(messages) # print(output) # messages = [ # {"role": "system", "content": "you are a helpful assistant"}, # {"role": "user", "content": "hi, what your name"}, # {"role": "assistant", "content": "My name is Jordan"} # ] # streamer = bot.generate_query(messages) print(list(streamer))