import time
from argparse import ArgumentParser

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

parser = ArgumentParser()
parser.add_argument("--model_path", "--model-path", required=True)
parser.add_argument("--prompt", "-p", required=True)
parser.add_argument("--max-tokens", "--max_tokens", type=int, default=100)
parser.add_argument("--min_p", "--min-p", type=float, default=0.3)
parser.add_argument("--temp", type=float, default=1.0)
args = parser.parse_args()

# Deferred import: coremltools is slow to load, so parse the args (and fail
# fast on bad input) before paying for it.
import coremltools as ct

print("Loading model...")
# An .mlpackage is loaded with MLModel (Core ML compiles it on load); anything
# else is assumed to be a pre-compiled .mlmodelc and loaded with
# CompiledMLModel, which starts up much faster. Both function_names come from
# the same multifunction model and share weights and KV-cache state:
# "length_64" processes the padded prompt in one pass, "length_1" decodes one
# token at a time.
if args.model_path.rstrip("/").endswith(".mlpackage"):
    mf_model_1 = ct.models.MLModel(
        args.model_path,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        function_name="length_1",
    )
    mf_model_64 = ct.models.MLModel(
        args.model_path,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        function_name="length_64",
    )
else:
    mf_model_1 = ct.models.CompiledMLModel(
        args.model_path,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        function_name="length_1",
    )
    mf_model_64 = ct.models.CompiledMLModel(
        args.model_path,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        function_name="length_64",
    )


def min_p_sample(logits, min_p, temp):
    # Min-p sampling: drop every token whose probability is below min_p times
    # the top token's probability, then draw from what remains.
    # logits = logits.astype(np.float16)
    logits = logits * (1 / temp)  # apply temperature before the max, so the
    max_ = np.max(logits, axis=1, keepdims=True)  # scaled top logit maps to 1
    logits = logits - max_
    # After subtracting the max, exp gives unnormalized probabilities with the
    # top token at exactly 1, so comparing against min_p directly implements
    # the p_i / p_max >= min_p cutoff without normalizing.
    logits = np.exp(logits)
    logits[logits < min_p] = 0
    # logits = logits.astype(np.float32)
    # Inverse-CDF sampling over the surviving tokens.
    logits = np.cumsum(logits, axis=1)
    sample = np.random.uniform(high=logits[:, -1:])
    sample = np.argmax(logits > sample, axis=1).astype(np.int32)
    return sample


length = len(tokenizer(args.prompt)["input_ids"])
# Pad the prompt to the fixed 64-token window that length_64 expects.
input_ids = tokenizer(
    args.prompt, return_tensors="np", padding="max_length", max_length=64
)["input_ids"].astype(np.int32)

print("Prompt:", args.prompt)

state = mf_model_64.make_state()

# Prompt processing: a single 64-token forward pass fills the KV cache.
start = time.time()
pred = mf_model_64.predict(
    {"input_ids": input_ids, "query_pos1": np.array([0], dtype=np.int32)}, state
)
prompt_time = time.time() - start

# Sample the first new token from the logits at the last real (unpadded)
# prompt position. Greedy alternative:
# input_ids = pred["logits"][..., length - 1].argmax(1, keepdims=True).astype(np.int32)
logits = pred["logits"][..., [length - 1]]
input_ids = min_p_sample(logits, args.min_p, args.temp)

print("Generated:")
print(tokenizer.decode(input_ids[0]), end="", flush=True)

# Decode loop: feed one token at a time, reusing the shared KV-cache state.
start = time.time()
for i in range(args.max_tokens):
    pred = mf_model_1.predict(
        {"input_ids": input_ids, "query_pos1": np.array([i + length], dtype=np.int32)},
        state,
    )
    input_ids = min_p_sample(pred["logits"], args.min_p, args.temp)
    # Greedy alternative: input_ids = pred["logits"].argmax(1).astype(np.int32)
    print(tokenizer.decode(input_ids[0]), end="", flush=True)
generation_time = time.time() - start  # stop the clock before printing the divider
print("", "=" * 10)

print(
    "Prompt:",
    length / prompt_time,
    "tokens-per-sec",
    f"({64 / prompt_time} considering the processed padding)",
)
print("Generation:", args.max_tokens / generation_time, "tokens-per-sec")
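
# Example invocation (hypothetical: the script name and model path are
# placeholders for whatever your conversion step produced; the flag spellings
# follow the parser above, and either dash/underscore alias works):
#
#   python generate.py \
#       --model-path Qwen2-0.5B.mlpackage \
#       -p "Write a haiku about the Neural Engine:" \
#       --max-tokens 100 --min-p 0.3 --temp 1.0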