from __future__ import annotations

import configparser
import pathlib
import typing
import os

import torch
import transformers
from torch.nn.utils.rnn import pad_sequence

from .config import LYRA_LLAMA_PARAM, LIB_SO_PATH
from .model import LlamaModel


class lyraLlama:
    def __init__(self, model_path, tokenizer_path=None, dtype='fp16', memopt_mode=0,
                 quant_dtype="int4", kvqparams_fpath="") -> None:
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.kvqparams_fpath = kvqparams_fpath

        self.dtype = dtype
        self.memopt_mode = memopt_mode
        self.quant_data_type = quant_dtype

        self.model, self.tokenizer = self.load_model_and_tokenizer()
        print("Got model and tokenizer")

    def load_model_and_tokenizer(self):
        if self.tokenizer_path is None:
            tokenizer_path = self.model_path
        else:
            tokenizer_path = self.tokenizer_path

        print(f'Loading tokenizer from {tokenizer_path}')
        tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

        checkpoint_path = pathlib.Path(self.model_path)
        config_path = checkpoint_path / 'config.ini'

        if config_path.exists():
            # Read the model hyperparameters from the checkpoint's config.ini.
            cfg = configparser.ConfigParser()
            cfg.read(config_path)
            model_name = 'llama'
            inference_data_type = self.dtype
            if inference_data_type is None:
                inference_data_type = cfg.get(model_name, "weight_data_type")
            model_args = dict(
                head_num=cfg.getint(model_name, 'head_num'),
                kv_head_num=cfg.getint(model_name, 'kv_head_num', fallback=0),
                size_per_head=cfg.getint(model_name, "size_per_head"),
                inter_size=cfg.getint(model_name, 'inter_size'),
                layer_num=cfg.getint(model_name, "num_layer"),
                rotary_embedding_dim=cfg.getint(model_name, 'rotary_embedding'),
                layernorm_eps=cfg.getfloat(model_name, 'layernorm_eps'),
                vocab_size=cfg.getint(model_name, "vocab_size"),
                start_id=cfg.getint(model_name, "start_id"),
                end_id=cfg.getint(model_name, "end_id"),
                weights_data_type=cfg.get(model_name, "weight_data_type"),
                tensor_para_size=cfg.getint(model_name, "tensor_para_size"),
                inference_data_type=inference_data_type,
                rope_theta=cfg.getfloat(model_name, "rope_theta", fallback=10000.0))
        else:
            # No config.ini found: fall back to the defaults in LYRA_LLAMA_PARAM.
            inference_data_type = self.dtype
            if inference_data_type is None:
                inference_data_type = LYRA_LLAMA_PARAM.weights_data_type
            model_args = dict(head_num=LYRA_LLAMA_PARAM.num_heads,
                              size_per_head=LYRA_LLAMA_PARAM.size_per_head,
                              inter_size=LYRA_LLAMA_PARAM.inter_size,
                              layer_num=LYRA_LLAMA_PARAM.num_layers,
                              rotary_embedding_dim=LYRA_LLAMA_PARAM.rotary_embedding,
                              layernorm_eps=LYRA_LLAMA_PARAM.layernorm_eps,
                              vocab_size=LYRA_LLAMA_PARAM.vocab_size,
                              start_id=LYRA_LLAMA_PARAM.start_id or tokenizer.bos_token_id,
                              end_id=LYRA_LLAMA_PARAM.end_id or tokenizer.eos_token_id,
                              weights_data_type=LYRA_LLAMA_PARAM.weights_data_type,
                              tensor_para_size=LYRA_LLAMA_PARAM.tensor_para_size,
                              inference_data_type=inference_data_type)

        model_args.update(dict(
            lib_path=LIB_SO_PATH,
            model_path=os.path.join(self.model_path, "1-gpu-fp16.bin"),
            kvqparams_fpath=self.kvqparams_fpath,
            max_seq_len=0,
            pipeline_para_size=LYRA_LLAMA_PARAM.pipeline_para_size,
            use_gptj_residual=LYRA_LLAMA_PARAM.use_gptj_residual,
            memopt_mode=self.memopt_mode,
            quant_data_type=self.quant_data_type))

        print('[LYRA][INFO] Load Our LYRA Highly Optimized LLaMA model')
        for k, v in model_args.items():
            print(f' - {k.ljust(25, ".")}: {v}')

        checklist = ['head_num', 'size_per_head', 'vocab_size', 'layer_num',
                     'tensor_para_size', 'weights_data_type']
        if None in [model_args[k] for k in checklist]:
            none_params = [p for p in checklist if model_args[p] is None]
            print(f'[LYRA][WARNING] Found None parameters {none_params}. They must '
                  'be provided either by the config file or CLI arguments.')
        if model_args['start_id'] != tokenizer.bos_token_id:
            print('[LYRA][WARNING] Given start_id does not match the bos token id '
                  'of the pretrained tokenizer.')
        if model_args['end_id'] not in (tokenizer.pad_token_id, tokenizer.eos_token_id):
            print('[LYRA][WARNING] Given end_id matches neither the pad token id '
                  'nor the eos token id of the pretrained tokenizer.')

        print(f'Loading model from {self.model_path}')
        model = LlamaModel(**model_args)
        return model, tokenizer

    def generate(self, prompts: typing.List[str] | str,
                 output_length: int = 512,
                 beam_width: int = 1,
                 top_k: typing.Optional[torch.IntTensor] = 1,
                 top_p: typing.Optional[torch.FloatTensor] = 1.0,
                 beam_search_diversity_rate: typing.Optional[torch.FloatTensor] = 0.0,
                 temperature: typing.Optional[torch.FloatTensor] = 1.0,
                 len_penalty: typing.Optional[torch.FloatTensor] = 0.0,
                 repetition_penalty: typing.Optional[torch.FloatTensor] = 1.0,
                 presence_penalty: typing.Optional[torch.FloatTensor] = None,
                 min_length: typing.Optional[torch.IntTensor] = None,
                 bad_words_list: typing.Optional[torch.IntTensor] = None,
                 do_sample: bool = False,
                 return_output_length: bool = False,
                 return_cum_log_probs: int = 0):
        if isinstance(prompts, str):
            prompts = [prompts]
        inputs = prompts

        batch_size = len(inputs)
        ones_int = torch.ones(size=[batch_size], dtype=torch.int32)
        ones_float = torch.ones(size=[batch_size], dtype=torch.float32)

        # Tokenize each prompt separately, then right-pad the batch to a common length.
        input_token_ids = [self.tokenizer(text, return_tensors="pt").input_ids.int().squeeze()
                           for text in inputs]
        input_lengths = torch.IntTensor([len(ids) for ids in input_token_ids])
        input_token_ids = pad_sequence(input_token_ids, batch_first=True,
                                       padding_value=self.tokenizer.eos_token_id)

        random_seed = None
        if do_sample:
            random_seed = torch.randint(0, 262144, (batch_size,), dtype=torch.long)

        outputs = self.model(start_ids=input_token_ids,
                             start_lengths=input_lengths,
                             output_len=output_length,
                             beam_width=beam_width,
                             top_k=top_k * ones_int,
                             top_p=top_p * ones_float,
                             beam_search_diversity_rate=beam_search_diversity_rate * ones_float,
                             temperature=temperature * ones_float,
                             len_penalty=len_penalty * ones_float,
                             repetition_penalty=repetition_penalty * ones_float,
                             random_seed=random_seed,
                             return_output_length=return_output_length,
                             return_cum_log_probs=return_cum_log_probs)

        if return_cum_log_probs > 0:
            outputs = outputs[0]

        # Keep beam 0 only and drop the prompt tokens before decoding.
        output_token_ids = [out[0, length:].cpu()
                            for out, length in zip(outputs, input_lengths)]
        output_texts = self.tokenizer.batch_decode(
            output_token_ids, skip_special_tokens=True)

        return output_texts

    def stream_generate(self, prompts: typing.List[str] | str,
                        output_length: int = 512,
                        beam_width: int = 1,
                        top_k: typing.Optional[torch.IntTensor] = 1,
                        top_p: typing.Optional[torch.FloatTensor] = 1.0,
                        beam_search_diversity_rate: typing.Optional[torch.FloatTensor] = 0.0,
                        temperature: typing.Optional[torch.FloatTensor] = 1.0,
                        len_penalty: typing.Optional[torch.FloatTensor] = 0.0,
                        repetition_penalty: typing.Optional[torch.FloatTensor] = 1.0,
                        presence_penalty: typing.Optional[torch.FloatTensor] = None,
                        min_length: typing.Optional[torch.IntTensor] = None,
                        bad_words_list: typing.Optional[torch.IntTensor] = None,
                        do_sample: bool = False,
                        return_output_length: bool = False,
                        return_cum_log_probs: int = 0):
        if isinstance(prompts, str):
            prompts = [prompts]
        inputs = prompts

        batch_size = len(inputs)
        ones_int = torch.ones(size=[batch_size], dtype=torch.int32)
        ones_float = torch.ones(size=[batch_size], dtype=torch.float32)

        input_token_ids = [self.tokenizer(text, return_tensors="pt").input_ids.int().squeeze()
                           for text in inputs]
        input_lengths = torch.IntTensor([len(ids) for ids in input_token_ids])
        input_token_ids = pad_sequence(input_token_ids, batch_first=True,
                                       padding_value=self.tokenizer.eos_token_id)

        random_seed = None
        if do_sample:
            random_seed = torch.randint(0, 262144, (batch_size,), dtype=torch.long)

        for finish, output_ids, sequence_length, output_cum_log_probs in self.model.stream_forward(
                start_ids=input_token_ids,
                start_lengths=input_lengths,
                output_len=output_length,
                beam_width=beam_width,
                top_k=top_k * ones_int,
                top_p=top_p * ones_float,
                beam_search_diversity_rate=beam_search_diversity_rate * ones_float,
                temperature=temperature * ones_float,
                len_penalty=len_penalty * ones_float,
                repetition_penalty=repetition_penalty * ones_float,
                random_seed=random_seed,
                return_output_length=return_output_length,
                return_cum_log_probs=return_cum_log_probs):
            # Decode the partial hypotheses (beam 0, prompt tokens dropped) at every step.
            output_token_ids = [out[0, length:].cpu()
                                for out, length in zip(output_ids, input_lengths)]
            output_texts = self.tokenizer.batch_decode(
                output_token_ids, skip_special_tokens=True)

            yield finish, output_texts
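

# A minimal usage sketch (not part of the original module). The checkpoint path below
# is a placeholder: it is assumed to point at a converted lyraLlama checkpoint directory
# containing config.ini and the 1-gpu-fp16.bin weights that load_model_and_tokenizer()
# expects.
if __name__ == "__main__":
    llm = lyraLlama("/path/to/lyra-llama-checkpoint", dtype='fp16', memopt_mode=0)

    # Batched, non-streaming generation.
    for text in llm.generate(["Hello, my name is", "The capital of France is"],
                             output_length=64):
        print(text)

    # Streaming generation: each step yields (finish, partial_texts).
    for finish, partial_texts in llm.stream_generate("Tell me a short story.",
                                                     output_length=128):
        print(finish, partial_texts[0])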