"""Inference for FastChat models.""" | |
import abc | |
import gc | |
import json | |
import math | |
import os | |
import sys | |
import time | |
from typing import Iterable, Optional, Dict | |
import warnings | |
import psutil | |
import torch | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
LlamaTokenizer, | |
LlamaForCausalLM, | |
AutoModel, | |
AutoModelForSeq2SeqLM, | |
T5Tokenizer, | |
AutoConfig, | |
) | |
from transformers.generation.logits_process import ( | |
LogitsProcessorList, | |
RepetitionPenaltyLogitsProcessor, | |
TemperatureLogitsWarper, | |
TopKLogitsWarper, | |
TopPLogitsWarper, | |
) | |
from src.conversation import get_conv_template, SeparatorStyle | |
from src.model.model_adapter import ( | |
load_model, | |
get_conversation_template, | |
get_generate_stream_function, | |
) | |
from src.modules.awq import AWQConfig | |
from src.modules.gptq import GptqConfig | |
from src.modules.exllama import ExllamaConfig | |
from src.modules.xfastertransformer import XftConfig | |
from src.utils import is_partial_stop, is_sentence_complete, get_context_length | |


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0; 1.0 makes it a no-op, so we skip both cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list
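
# Illustrative sketch (the values are arbitrary, not project defaults): the
# returned LogitsProcessorList is applied to the last-step logits before
# sampling, mirroring how generate_stream() uses it below.
#
#   processors = prepare_logits_processor(
#       temperature=0.7, repetition_penalty=1.1, top_p=0.9, top_k=40
#   )
#   warped = processors(torch.as_tensor([output_ids]), logits[:, -1, :])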


def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
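    """Stream text generated by `model` for the prompt in `params`.

    `params` supports "prompt", "temperature", "repetition_penalty", "top_p",
    "top_k", "max_new_tokens", "logprobs", "echo", "stop", and
    "stop_token_ids". Yields dicts with the decoded "text" so far, optional
    "logprobs", token "usage", and a "finish_reason" that stays None until the
    final chunk ("stop" or "length").
    """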
    if hasattr(model, "device"):
        device = model.device

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    logprobs = params.get("logprobs", None)  # FIXME: Support logprobs>1.
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:  # truncate
        max_src_len = context_len - max_new_tokens - 1

    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:  # FIXME: Support logprobs for encoder-decoder models.
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # The first token has no logprobs.
    sent_interrupt = False
    finish_reason = None
    stopped = False
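    # Step 0 runs the whole prompt through the model (prefill) and caches
    # past_key_values; later steps feed a single token, except right after a
    # sentence-end interrupt, when the full output_ids are re-fed without cache.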
    for i in range(max_new_tokens):
        if i == 0:  # prefill
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Prefill logprobs for the prompt.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:  # decoding
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Switch to CPU to avoid some bugs in the MPS backend.
            last_token_logits = last_token_logits.float().to("cpu")
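        # Keep the top-2 candidates: if judge_sent_end later rejects a stop
        # token that would end mid-sentence, the runner-up token is used instead.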
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)
        if logprobs is not None:
            # Cannot use last_token_logits because logprobs is based on raw logits.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Yield the output tokens
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )

            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # Compute text_offset
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)
            # TODO: Patch for the case where a stop token interrupts an incomplete
            # sentence; this could be handled more elegantly.
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Prevent yielding a partial stop sequence
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }
        if stopped:
            break

    # Finish stream event, which contains the finish reason
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Clean up
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
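
# Illustrative sketch of driving the generator directly (model, tokenizer, and
# device are placeholders; chat_loop() below goes through
# get_generate_stream_function() instead of calling this function by name):
#
#   params = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 64, "echo": False}
#   for chunk in generate_stream(model, tokenizer, params, device="cuda", context_len=2048):
#       print(chunk["text"], end="\r")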


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""

    @abc.abstractmethod
    def print_output(self, text: str):
        """Print output."""


def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
):
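    """Run an interactive chat session over the given ChatIO.

    Besides free-form messages, the loop understands a few commands
    (implemented below): "!!exit" or an empty line quits, "!!reset" starts a
    new conversation, "!!remove" deletes the last message, "!!regen"
    regenerates the last assistant reply, and "!!save <file>" / "!!load <file>"
    persist or restore the conversation as JSON.
    """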
    # Model
    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type
    is_xft = "xft" in model_type

    # Hardcode T5's default repetition penalty to be 1.2
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    # Set context length
    context_len = get_context_length(model.config)

    # Chat
    def new_chat():
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """
        Reprints the conversation from the start.
        """
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])
    conv = None

    while True:
        if not history or not conv:
            conv = new_chat()

        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""

        if inp == "!!exit" or not inp:
            print("exit...")
            break
        elif inp == "!!reset":
            print("resetting...")
            conv = new_chat()
            continue
        elif inp == "!!remove":
            print("removing last message...")
            if len(conv.messages) > conv.offset:
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()
                reload_conv(conv)
            else:
                print("No messages to remove.")
            continue
        elif inp == "!!regen":
            print("regenerating last message...")
            if len(conv.messages) > conv.offset:
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User
                if conv.messages[-1][0] == conv.roles[0]:
                    reload_conv(conv)
                    # Set inp to previous message
                    inp = conv.messages.pop()[1]
                else:
                    # Shouldn't happen in normal circumstances
                    print("No user message to regenerate from.")
                    continue
            else:
                print("No messages to regenerate.")
                continue
        elif inp.startswith("!!save"):
            args = inp.split(" ", 1)

            if len(args) != 2:
                print("usage: !!save <filename>")
                continue
            else:
                filename = args[1]

            # Add .json if extension not present
            if "." not in filename:
                filename += ".json"

            print("saving...", filename)
            with open(filename, "w") as outfile:
                json.dump(conv.dict(), outfile)
            continue
        elif inp.startswith("!!load"):
            args = inp.split(" ", 1)

            if len(args) != 2:
                print("usage: !!load <filename>")
                continue
            else:
                filename = args[1]

            # Check if file exists and add .json if needed
            if not os.path.exists(filename):
                if (not filename.endswith(".json")) and os.path.exists(
                    filename + ".json"
                ):
                    filename += ".json"
                else:
                    print("file not found:", filename)
                    continue

            print("loading...", filename)
            with open(filename, "r") as infile:
                new_conv = json.load(infile)

            conv = get_conv_template(new_conv["template_name"])
            conv.set_system_message(new_conv["system_message"])
            conv.messages = new_conv["messages"]
            reload_conv(conv)
            continue
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        if is_codet5p:  # codet5p is a code completion model.
            prompt = inp

        gen_params = {
            "model": model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }

        try:
            chatio.prompt_for_output(conv.roles[1])
            output_stream = generate_stream_func(
                model,
                tokenizer,
                gen_params,
                device,
                context_len=context_len,
                judge_sent_end=judge_sent_end,
            )
            t = time.time()
            outputs = chatio.stream_output(output_stream)
            duration = time.time() - t
            conv.update_last_message(outputs.strip())

            if debug:
                num_tokens = len(tokenizer.encode(outputs))
                msg = {
                    "conv_template": conv.name,
                    "prompt": prompt,
                    "outputs": outputs,
                    "speed (token/s)": round(num_tokens / duration, 2),
                }
                print(f"\n{msg}\n")
        except KeyboardInterrupt:
            print("stopped generation.")
            # If generation didn't finish
            if conv.messages[-1][1] is None:
                conv.messages.pop()
                # Remove the last user message, so there isn't a duplicate
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()

                reload_conv(conv)
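
# Illustrative invocation (the model path and settings are placeholders, and
# BasicChatIO refers to the sketch above, not a class defined in this module):
#
#   chat_loop(
#       model_path="lmsys/vicuna-7b-v1.5",
#       device="cuda",
#       num_gpus=1,
#       max_gpu_memory=None,
#       dtype=None,
#       load_8bit=False,
#       cpu_offloading=False,
#       conv_template=None,
#       conv_system_msg=None,
#       temperature=0.7,
#       repetition_penalty=1.0,
#       max_new_tokens=512,
#       chatio=BasicChatIO(),
#   )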