import gc
import copy
import time

from tenacity import RetryError
from tenacity import retry, stop_after_attempt, wait_fixed

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsProcessorList,
    MinNewTokensLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    MinLengthLogitsProcessor
)


def get_output_batch(
    model, tokenizer, prompts, generation_config
):
    """Run generation for one or more prompts and return the decoded outputs."""
    if len(prompts) == 1:
        # Single prompt: no padding is needed.
        encoding = tokenizer(prompts, return_tensors="pt")
        input_ids = encoding["input_ids"].cuda()

        generated_id = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_id)
        del input_ids, generated_id
        torch.cuda.empty_cache()
        return decoded
    else:
        # Multiple prompts: pad them to a common length so they can be batched.
        encodings = tokenizer(prompts, padding=True, return_tensors="pt").to('cuda')
        generated_ids = model.generate(
            **encodings,
            generation_config=generation_config,
            max_new_tokens=256
        )

        decoded = tokenizer.batch_decode(generated_ids)
        del encodings, generated_ids
        torch.cuda.empty_cache()
        return decoded
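

# Illustrative sketch only (not part of the original module): one way
# get_output_batch might be called. The checkpoint name below is a placeholder,
# and the GenerationConfig values are assumptions, not taken from this file.
def _example_get_output_batch():
    from transformers import GenerationConfig

    model_name = "your-org/your-causal-lm"  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16
    ).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # A pad token is required when several prompts are batched together.
    tokenizer.pad_token = tokenizer.eos_token

    generation_config = GenerationConfig(
        temperature=0.9, top_p=0.75, do_sample=True
    )
    prompts = [
        "Instruction: say hello.\nResponse:",
        "Instruction: say goodbye.\nResponse:",
    ]
    print(get_output_batch(model, tokenizer, prompts, generation_config))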


# StreamModel is borrowed from the basaran project.
# Please find more info about it here: https://github.com/hyperonym/basaran
class StreamModel:
    """StreamModel wraps around a language model to provide stream decoding."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.processor = LogitsProcessorList()
        self.processor.append(TemperatureLogitsWarper(0.9))
        self.processor.append(TopPLogitsWarper(0.75))

    def __call__(
        self,
        prompt,
        min_tokens=0,
        max_tokens=16,
        temperature=1.0,
        top_p=1.0,
        n=1,
        logprobs=0,
    ):
        """Create a completion stream for the provided prompt."""
        input_ids = self.tokenize(prompt)
        logprobs = max(logprobs, 0)

        # Decode and yield once every `chunk_size` tokens; keep it bigger than 1
        # to avoid re-decoding the text on every single token.
        chunk_size = 2
        chunk_count = 0

        # Generate completion tokens.
        final_tokens = torch.empty(0, dtype=torch.long)

        for tokens in self.generate(
            input_ids[None, :].repeat(n, 1),
            logprobs=logprobs,
            min_new_tokens=min_tokens,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            final_tokens = torch.cat((final_tokens, tokens.to("cpu")))
            chunk_count += 1

            # Emit the full decoded text so far once per chunk.
            if chunk_count == chunk_size:
                chunk_count = 0
                yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        # Flush any remaining tokens that did not fill a whole chunk.
        if chunk_count > 0:
            yield self.tokenizer.decode(final_tokens, skip_special_tokens=True)

        del final_tokens, input_ids
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def _infer(self, model_fn, **kwargs):
        with torch.inference_mode():
            return model_fn(**kwargs)

    def tokenize(self, text):
        """Tokenize a string into a tensor of token IDs."""
        batch = self.tokenizer.encode(text, return_tensors="pt")
        return batch[0].to(self.device)

    def generate(self, input_ids, logprobs=0, **kwargs):
        """Generate a stream of predicted tokens using the language model."""

        # Store the original batch size and input length.
        batch_size = input_ids.shape[0]
        input_length = input_ids.shape[-1]

        # Separate model arguments from generation config.
        config = self.model.generation_config
        config = copy.deepcopy(config)
        kwargs = config.update(**kwargs)
        kwargs["output_attentions"] = False
        kwargs["output_hidden_states"] = False
        kwargs["use_cache"] = True

        # Collect special token IDs.
        pad_token_id = config.pad_token_id
        bos_token_id = config.bos_token_id
        eos_token_id = config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if pad_token_id is None and eos_token_id is not None:
            pad_token_id = eos_token_id[0]

        # Generate from eos if no input is specified.
        if input_length == 0:
            input_ids = input_ids.new_ones((batch_size, 1)).long()
            if eos_token_id is not None:
                input_ids = input_ids * eos_token_id[0]
            input_length = 1

        # Keep track of which sequences are already finished.
        unfinished = input_ids.new_ones(batch_size)

        # Start auto-regressive generation.
        while True:
            inputs = self.model.prepare_inputs_for_generation(
                input_ids, **kwargs
            )  # noqa: E501
            outputs = self._infer(
                self.model,
                **inputs,
                # return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            # Pre-process the probability distribution of the next tokens.
            logits = outputs.logits[:, -1, :]
            with torch.inference_mode():
                logits = self.processor(input_ids, logits)
            probs = torch.nn.functional.softmax(logits, dim=-1)

            # Select deterministic or stochastic decoding strategy.
            if (config.top_p is not None and config.top_p <= 0) or (
                config.temperature is not None and config.temperature <= 0
            ):
                tokens = torch.argmax(probs, dim=-1)[:, None]
            else:
                tokens = torch.multinomial(probs, num_samples=1)
            tokens = tokens.squeeze(1)

            # Finished sequences should have their next token be a padding.
            if pad_token_id is not None:
                tokens = tokens * unfinished + pad_token_id * (1 - unfinished)

            # Append selected tokens to the inputs.
            input_ids = torch.cat([input_ids, tokens[:, None]], dim=-1)

            # Mark sequences as finished once their token matches any eos id.
            if eos_token_id is not None:
                not_eos = torch.ones_like(tokens)
                for i in eos_token_id:
                    not_eos = not_eos * (tokens != i).long()
                unfinished = unfinished.mul(not_eos)

            # Set status to -1 if exceeded the max length.
            status = unfinished.clone()
            if input_ids.shape[-1] - input_length >= config.max_new_tokens:
                status = 0 - status

            # Yield predictions and status.
            yield tokens

            # Stop when finished or exceeded the max length.
            if status.max() <= 0:
                break
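

# Illustrative sketch only (not part of the original module): one way to consume
# the StreamModel above. The prompt text and token budget are assumptions;
# `model` and `tokenizer` are expected to be loaded as in the earlier sketch.
def _example_stream_model(model, tokenizer):
    stream_model = StreamModel(model, tokenizer)

    # Each yielded string is the full decoded completion so far, so the last
    # printed line is the final answer.
    for partial_text in stream_model(
        "Instruction: say hello.\nResponse:",
        max_tokens=64,
        temperature=0.9,
        top_p=0.75,
    ):
        print(partial_text)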