# Uploaded by JohnTan38 (duplicated from chansung/Alpaca-LoRA-Serve, commit 630f532).
# Original file size: 6.94 kB.
import gc
import copy
import time
from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed
import torch
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
LogitsProcessorList,
MinNewTokensLengthLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
MinLengthLogitsProcessor
)
def get_output_batch(
    model, tokenizer, prompts, generation_config, max_new_tokens=256, device="cuda"
):
    """Generate completions for a list of prompts with one ``model.generate`` call.

    Args:
        model: model exposing ``generate``.
        tokenizer: tokenizer used to encode ``prompts`` and decode the outputs.
        prompts: list of prompt strings.
        generation_config: passed through unchanged to ``model.generate``.
        max_new_tokens: maximum number of new tokens per prompt (default 256,
            the value that used to be hard-coded).
        device: device the encoded inputs are moved to (default ``"cuda"``).

    Returns:
        List with one decoded sequence (prompt + completion, special tokens
        kept) per prompt; an empty list when ``prompts`` is empty.
    """
    if not prompts:
        return []

    # Only pad when batching: a single prompt needs no padding, and asking for
    # padding fails on tokenizers that have no pad token.
    encodings = tokenizer(prompts, padding=len(prompts) > 1, return_tensors="pt")
    # Forward every encoded field (input_ids AND attention_mask) for both the
    # single-prompt and the batched case.
    model_inputs = {name: tensor.to(device) for name, tensor in encodings.items()}
    generated_ids = model.generate(
        **model_inputs,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
    )
    decoded = tokenizer.batch_decode(generated_ids)

    # Drop references to the device tensors before releasing cached memory.
    del encodings, model_inputs, generated_ids
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return decoded
# StreamModel is borrowed from the basaran project;
# see https://github.com/hyperonym/basaran for more information.
class StreamModel:
    """Wrap a language model and tokenizer to provide token-by-token stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): these warpers are fixed at construction time; the
        # ``temperature`` / ``top_p`` arguments of ``__call__`` do not change
        # them (they only choose greedy vs. sampled decoding in ``generate``).
        self.processor = LogitsProcessorList()
        self.processor.append(TemperatureLogitsWarper(0.9))
        self.processor.append(TopPLogitsWarper(0.75))

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for ``prompt``.

        After every generated token, yields the decoded text of all tokens
        generated so far (cumulative, ``skip_special_tokens=True``).

        Args:
            prompt: text to complete.
            min_tokens: minimum number of new tokens before EOS may be chosen.
            max_tokens: maximum number of new tokens.
            temperature: applied to the generation config; ``<= 0`` selects
                greedy decoding.
            top_p: applied to the generation config; ``<= 0`` selects greedy
                decoding.
            n: number of copies of the prompt decoded in parallel.
                NOTE(review): for n > 1 the per-step tokens of all copies are
                concatenated into one flat tensor, so the decoded text
                interleaves the sequences.
            logprobs: clamped to ``>= 0`` and forwarded to ``generate``, which
                does not use it.
        """
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Integer dtype so ``decode`` receives token ids, not floats
        # (``torch.empty(0)`` is float32 and would promote every concatenated
        # token to float).
        final_tokens = torch.empty(0, dtype=torch.long)
        stream = self.generate(
            input_ids[None, :].repeat(n, 1),
            logprobs=logprobs,
            min_new_tokens=min_tokens,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        try:
            for tokens in stream:
                final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)
        finally:
            # Runs on normal completion AND when the consumer closes this
            # stream early, so cached device memory is always released.
            stream.close()
            del final_tokens, input_ids
            if self.device == "cuda":
                torch.cuda.empty_cache()

    def _infer(self, model_fn, **kwargs):
        """Call ``model_fn(**kwargs)`` under ``torch.inference_mode``."""
        with torch.inference_mode():
            return model_fn(**kwargs)

    def tokenize(self, text):
        """Encode ``text`` and return the first row of token ids on ``self.device``."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Stream predicted tokens from the language model, one step at a time.

        Args:
            input_ids: 2-D tensor of prompt token ids, one row per sequence.
            logprobs: accepted for interface compatibility; unused.
            **kwargs: generation settings applied to a copy of
                ``self.model.generation_config`` (e.g. ``max_new_tokens``,
                ``min_new_tokens``, ``temperature``, ``top_p``). Keys the
                config does not recognise are forwarded to
                ``prepare_inputs_for_generation``.

        Yields:
            1-D tensor with one new token id per sequence. Sequences that have
            already produced EOS receive the pad token id instead.
        """
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Work on a copy so per-call settings never leak into the model's
        # shared config; ``update`` returns the keys it did not recognise.
        config = copy.deepcopy(self.model.generation_config)
        kwargs = config.update(**kwargs)
        # min_new_tokens is enforced manually below. Pop it so it is not
        # forwarded to the model if this config class lacks the attribute.
        min_new_tokens = kwargs.pop("min_new_tokens", None)
        if min_new_tokens is None:
            min_new_tokens = getattr(config, "min_new_tokens", None) or 0
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True

        # Normalise EOS to a list and fall back to it for padding.
        pad_token_id = config.pad_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # With an empty prompt, start every sequence from a single EOS token
        # (token id 1 when no EOS is configured).
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # 1 while a sequence is still running, 0 once it has produced EOS.
        unfinished = input_ids.new_ones(batch_size)

        while True:
            inputs = self.model.prepare_inputs_for_generation(input_ids, **kwargs)
            # NOTE(review): ``outputs.past_key_values`` is never fed back into
            # ``kwargs``, so every step re-runs the model over the whole
            # sequence even though ``use_cache=True`` is requested.
            outputs = self._infer(
                self.model,
                **inputs,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Only the distribution for the last position is needed.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                generated = input_ids.shape[-1] - input_length
                if eos_token_id is not None and generated < min_new_tokens:
                    # Forbid EOS until min_new_tokens tokens exist. Done before
                    # the warpers so top-p can never be left with only EOS.
                    eos_index = torch.tensor(eos_token_id, device=logits.device)
                    logits = logits.index_fill(1, eos_index, float("-inf"))
                logits = self.processor(input_ids, logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # A non-positive top_p or temperature selects greedy decoding.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)
            tokens = tokens.squeeze(1)

            # Finished sequences keep emitting padding.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # A sequence finishes when its token matches ANY eos id; summing
            # the per-id ``!=`` masks would stay non-zero with several eos ids.
            if eos_token_id is not None:
                not_eos = torch.stack([tokens != i for i in eos_token_id]).all(dim=0)
                unfinished = unfinished.mul(not_eos.long())

            # status: 1 = running, 0 = finished, negative = hit max_new_tokens.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            yield tokens

            # Stop once no sequence is still running within the length limit.
            if status.max() <= 0:
                break