import argparse
import sys
from time import perf_counter

# The lyra_baichuan package lives one directory up from this demo script.
sys.path.append('../')

from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B


def get_args():
    parser = argparse.ArgumentParser(description="Faster Baichuan Demo")

    parser.add_argument('--model-path', type=str, required=True,
                        help='Model path; must contain config.ini and the tokenizer files.')
    parser.add_argument('--tokenizer-path', type=str, default=None,
                        help='Tokenizer path. Defaults to None (use the tokenizer files under --model-path).')

    parser.add_argument(
        '--data-type', type=str, metavar='TYPE', default='fp16',
        choices=[None, 'fp32', 'fp16', 'bf16', 'int8'],
        help='The data type for inference. If None, the data type follows the '
             'checkpoint data type.')

    parser.add_argument(
        '--memopt_mode', type=int, default=0, choices=[0, 1],
        help='MEMOPT mode increases speed and reduces VRAM usage. '
             '0: FP16 mode, 1: MEMOPT mode.')

    parser.add_argument(
        '--quant-type', type=str, metavar='TYPE', default='int8',
        choices=['int4', 'int8'],
        help='The quantization data type. Only used in MEMOPT mode.')

    parser.add_argument("--prompt", type=str, required=False,
                        help="Prompt used for every request in the benchmark.")
    parser.add_argument("--max-output-length", type=int, default=512,
                        help="Maximum number of tokens to generate per request.")
    parser.add_argument("--warmups", type=int, default=10,
                        help="Number of warmup runs before timing.")
    parser.add_argument("--avgnums", type=int, default=10,
                        help="Number of timed runs to average over.")

    args = parser.parse_args()

    print('\n=================== Arguments ===================')
    for k, v in vars(args).items():
        print(f' - {k.ljust(25, ".")}: {v}')
    print('=================================================')

    return args


def main():
    args = get_args()

    # Switch to lyraBaichuan7B to benchmark the 7B model instead.
    # model = lyraBaichuan7B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode, args.quant_type)
    model = lyraBaichuan13B(args.model_path, args.tokenizer_path,
                            args.data_type, args.memopt_mode, args.quant_type)

    # prompt_template = "{}\n"  # baichuan chat
    prompt_template = "{}"      # baichuan base
    prompt = prompt_template.format(args.prompt)

    test_batch_size = [1, 2, 4]  # larger sizes: 8, 16, 32, 64
    print("test_batch_size: ", test_batch_size)
    for i, bs in enumerate(test_batch_size):
        prompts = [prompt] * bs

        # Warm up the GPU with the same generation settings as the timed runs.
        for _ in range(args.warmups):
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0,
                repetition_penalty=1.0, do_sample=False)

        # Time args.avgnums generations and report the average cost per run.
        start = perf_counter()
        for _ in range(args.avgnums):
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0,
                repetition_penalty=1.0, do_sample=False)
        end = perf_counter()
        cost = (end - start) / args.avgnums

        # Count tokens over prompt + generation; character count is used as a
        # rough proxy for words.
        input_output_texts = [p + ' ' + gtext for p, gtext in zip(prompts, output_texts)]
        tokens = 0
        input_tokens = len(model.tokenizer.encode(prompt))
        words = 0
        for text in input_output_texts:
            tokens += len(model.tokenizer.encode(text))
            words += len(text)

        print(
            f"\nFaster-Dtype: {args.data_type}, Batch Size: {bs}, All tokens: {tokens}. "
            f"Input tokens: {input_tokens}. Cost: {cost} seconds. Speed: {tokens/cost} tokens/s."
        )
        print(
            f"Faster-Dtype: {args.data_type}, Batch Size: {bs}, All generated words: {words}. "
            f"Cost: {cost} seconds. Speed: {words/cost} words/s."
        )

        # Print a few sample outputs for the first batch size only.
        if i == 0:
            for k in range(bs):
                print(
                    f"The {k} Sample, \n\t\tInputs: {prompts[k]}. \n\t\tOutputs: {output_texts[k].lstrip()}")
                if k > 2:
                    break


if __name__ == "__main__":
    main()
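
# Example invocation (a sketch only; the script name, model path, and prompt
# below are illustrative placeholders, not files shipped with this demo).
# All flags match the argparse definitions above.
#
#   python demo_baichuan.py \
#       --model-path /path/to/lyra-Baichuan-13B \
#       --data-type fp16 \
#       --memopt_mode 1 \
#       --quant-type int8 \
#       --prompt "Hello, who are you?" \
#       --max-output-length 256 \
#       --warmups 5 \
#       --avgnums 5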