import gc
from threading import Thread

import torch
from transformers import TextIteratorStreamer


@torch.inference_mode()
def generate_stream_xft(
    model,
    tokenizer,
    params,
    device,
    context_len=8192,
    stream_interval=2,
    judge_sent_end=False,
):
    """Stream generation from an xFasterTransformer-backed model.

    Yields dicts containing the accumulated text, token usage, and a finish reason.
    """
    prompt = params["prompt"]
    repetition_penalty = float(params.get("repetition_penalty", 1.0))

    max_new_tokens = int(params.get("max_new_tokens", 4096))
    echo = params.get("echo", True)

    # Tokenize the prompt; padding behavior comes from the model config.
    inputs = tokenizer(
        prompt, return_tensors="pt", padding=model.config.padding
    ).input_ids
    input_echo_len = len(inputs[0])
    max_len = max_new_tokens + input_echo_len

    # Stream decoded text as it is generated, skipping the prompt tokens.
    decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
    generation_kwargs = {
        "input_ids": inputs,
        "streamer": streamer,
        "max_length": max_len,
        "num_beams": model.config.beam_width,
        # Note: the request's repetition_penalty is forwarded as length_penalty here.
        "length_penalty": repetition_penalty,
        "num_return_sequences": model.config.num_return_sequences,
        "early_stopping": model.config.early_stopping,
        "eos_token_id": model.config.eos_token_id,
        "pad_token_id": model.config.pad_token_id,
    }

    # Run generation in a background thread and consume the streamer here.
    thread = Thread(target=model.model.generate, kwargs=generation_kwargs)
    thread.start()
    if echo:
        # Echo keeps the prompt at the start of the returned text.
        output = prompt
    else:
        output = ""
    i = 0
    for i, new_text in enumerate(streamer):
        output += new_text
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
            "finish_reason": None,
        }
    output = output.strip()
    if i == max_new_tokens - 1:
        finish_reason = "length"
    else:
        finish_reason = "stop"
    # Final chunk carries the finish reason for the whole generation.
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
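

# A minimal usage sketch (illustrative only): it assumes `model` is an
# xFasterTransformer wrapper loaded elsewhere and `tokenizer` is the matching
# Hugging Face tokenizer; the parameter values below are hypothetical and not
# part of this module.
#
#     params = {
#         "prompt": "Hello, how are you?",
#         "repetition_penalty": 1.0,
#         "max_new_tokens": 256,
#         "echo": False,
#     }
#     for chunk in generate_stream_xft(model, tokenizer, params, device="cpu"):
#         print(chunk["text"], end="\r")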