# Text generation with a Llama-3.2-1B-Instruct model exported to Core ML as three
# parts (token embeddings, a multifunction transformer with length_1 / length_40
# functions, and the LM head), sampled with min-p sampling on CPU + Neural Engine.
import time
import math
import numpy as np
from argparse import ArgumentParser
from transformers import AutoTokenizer
from dotenv import load_dotenv
import os

load_dotenv()

# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", token=os.environ["HF_TOKEN"]
)

parser = ArgumentParser()
parser.add_argument("--model_path_emb", "--model-path-emb", required=True)
parser.add_argument("--model_path_mf", "--model-path-mf", required=True)
# parser.add_argument("--model_path_1", "--model-path-1", required=True)
# parser.add_argument("--model_path_40", "--model-path-40", required=True)
parser.add_argument("--model_path_head", "--model-path-head", required=True)
parser.add_argument("--prompt", "-p", required=True, type=str)
parser.add_argument("--max-tokens", "--max_tokens", type=int, default=100)
parser.add_argument("--min_p", "--min-p", type=float, default=0.3)
parser.add_argument("--temp", type=float, default=1.0)
args = parser.parse_args()

# Deferred import: coremltools is only needed once the arguments parse successfully.
import coremltools as ct

print("Loading models...")

cu = ct.ComputeUnit.CPU_AND_NE


def load_model(path, fname=None):
    # Compiled .mlmodelc bundles skip the on-device compilation step, so they
    # load faster than .mlpackage archives.
    if "mlmodelc" in path:
        return ct.models.CompiledMLModel(path, cu, fname)
    else:
        return ct.models.MLModel(path, cu, function_name=fname)


emb_model = load_model(args.model_path_emb)  # token ids -> embeddings
model_1 = load_model(args.model_path_mf, "length_1")  # decode function: 1 token per call
model_40 = load_model(args.model_path_mf, "length_40")  # prefill function: 40-token chunks
model_head = load_model(args.model_path_head)  # final hidden states -> logits

# if args.model_path.rstrip("/").endswith(".mlpackage"):
#     mf_model_1 = ct.models.MLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_1",
#     )
#     mf_model_64 = ct.models.MLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_64",
#     )
# else:
#     mf_model_1 = ct.models.CompiledMLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_1",
#     )
#     mf_model_64 = ct.models.CompiledMLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_64",
#     )
# mf_model_emb = ct.models.MLModel(
#     # args.model_path_emb,
#     "./Llama-3.2-1B-EMB-16Bits.mlpackage",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )
# mf_model_mf = ct.models.MLModel(
#     # args.model_path_1,
#     "./Llama-3.2-1B-4bits-MF.mlpackage/",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )
# mf_model_40 = ct.models.MLModel(
#     # args.model_path_40,
#     "./Llama-3.2-1B-4bits-CTX-40.mlpackage",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )
# head = ct.models.MLModel(
#     # args.model_path_head,
#     "./Llama-3.2-1B-HEAD-6Bits.mlpackage",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )


def save_compiled(model):
    # Helper to cache the compiled .mlmodelc bundle next to its .mlpackage.
    from shutil import copytree

    compiled_model_path = model.get_compiled_model_path()
    copytree(
        compiled_model_path,
        model.package_path.replace(".mlpackage", ".mlmodelc"),
        dirs_exist_ok=True,
    )


def min_p_sample(logits, min_p, temp):
    # logits = logits.astype(np.float16)
    logits = logits * (1 / temp)  # temperature scaling
    max_ = np.max(logits, axis=1, keepdims=True)
    logits = logits - max_
    # exp(l_i - l_max) equals p_i / p_max, so zeroing values below min_p drops
    # every token whose probability is less than min_p times the top token's.
    logits = np.exp(logits)
    logits[logits < min_p] = 0
    # logits = logits.astype(np.float32)
    logits = np.cumsum(logits, axis=1)
    sample = np.random.uniform(high=logits[:, -1:])
    sample = np.argmax(logits > sample, axis=1).astype(np.int32)
    return sample
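
# Quick self-check of min_p_sample (illustrative values only, not part of the
# pipeline; safe to delete): with probabilities [0.70, 0.25, 0.05] and min_p=0.3,
# the last token falls below 0.3 * 0.70 and is filtered out, so only index 0 or 1
# can ever be drawn.
_demo_logits = np.log(np.array([[0.70, 0.25, 0.05]], dtype=np.float32))
assert min_p_sample(_demo_logits, min_p=0.3, temp=1.0)[0] in (0, 1)
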
def build_causal_mask(seq_length, start, size, end):
    mask = np.full((1, 1, size, seq_length), np.array(-np.inf, dtype=np.float16))
    i, h, j, k = np.indices(mask.shape)
    # Zero (open) the causal positions; for the padded rows past `end`, open the
    # first column only so their softmax never divides by zero.
    mask[((k <= (j + start)) & (j < end)) | ((j >= end) & (k == 0))] = 0
    return mask


if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

mask = build_causal_mask(512, 0, 512, 512)
max_length = 40  # prompt is processed in chunks of this many tokens

# length = len(tokenizer(args.prompt)["input_ids"])
prompt = [{"role": "user", "content": args.prompt}]
length = len(tokenizer.apply_chat_template(prompt, add_generation_prompt=True))
print("Prompt length:", length)

input_ids = tokenizer.apply_chat_template(
    prompt,
    return_tensors="np",
    padding=True,
    # max_length=max_length,
    return_dict=True,
    add_generation_prompt=True,
    tokenizer_kwargs={
        # "padding": True,
        "pad_to_multiple_of": max_length,
    },
)["input_ids"].astype(np.int32)
# input_ids = tokenizer(
#     args.prompt,
#     return_tensors="np",
#     padding="max_length",
#     max_length=max_length,
# )["input_ids"].astype(np.int32)
print("Prompt:\n", tokenizer.decode(input_ids[0]))

state = model_40.make_state()

# Prefill: run the padded prompt through the length_40 function in 40-token
# chunks, filling the KV-cache state as we go.
start = time.time()
for i in range(math.ceil(length / max_length)):
    input_embs = emb_model.predict(
        {"input_ids": input_ids[:, i * max_length : (i + 1) * max_length]}
    )["input_embeddings_channels_first"].astype(np.float16)
    pred = model_40.predict(
        {
            # The converted model's input is still named "input_ids" even though
            # it receives embeddings.
            "input_ids": input_embs,
            "query_pos1": np.array([i * max_length], dtype=np.int32),
            "mask": mask[:, :, i * max_length : (i + 1) * max_length],
            # "indices": np.array([0], dtype=np.int32),
            "indices": np.arange(i * max_length, (i + 1) * max_length, dtype=np.int32),
        },
        state,
    )
prompt_time = time.time() - start

# Project the hidden state of the last real (non-padding) prompt token to logits
# and sample the first generated token.
pred = model_head.predict(
    {
        "hidden_states": pred["final_norm_rmsnorm"][
            ..., [length % max_length - 1]
        ].astype(np.float16)
    }
)
# input_ids = pred["logits"][..., length - 1].argmax(1, keepdims=True).astype(np.int32)
# logits = pred["logits"][..., [length - 1]]
logits = pred["concat_0"]
input_ids = min_p_sample(logits, args.min_p, args.temp)

print("Generated:")
print(tokenizer.decode(input_ids[0]), end="", flush=True)

# Decode: generate one token at a time with the length_1 function, reusing the
# same KV-cache state.
start = time.time()
for i in range(args.max_tokens):
    input_embs = emb_model.predict({"input_ids": input_ids})[
        "input_embeddings_channels_first"
    ].astype(np.float16)
    pred = model_1.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i + length], dtype=np.int32),
            "mask": mask[:, :, [i + length]],
            "indices": np.array([i + length], dtype=np.int32),
        },
        state,
    )
    pred = model_head.predict(
        {"hidden_states": pred["final_norm_rmsnorm"].astype(np.float16)}
    )
    # input_ids = min_p_sample(pred["logits"], args.min_p, args.temp)
    input_ids = min_p_sample(pred["concat_0"], args.min_p, args.temp)
    # input_ids = pred["logits"].argmax(1).astype(np.int32)
    print(tokenizer.decode(input_ids[0]), end="", flush=True)
print("", "=" * 10)
generation_time = time.time() - start

print(
    "Prompt:",
    length / prompt_time,
    "tokens-per-sec",
    f"({math.ceil(length / max_length) * max_length / prompt_time} considering the processed padding)",
)
print("Generation:", args.max_tokens / generation_time, "tokens-per-sec")
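
# Example invocation (the script name and bundle paths below are illustrative; they
# mirror the .mlpackage paths in the commented-out loading code above, so substitute
# whatever models you actually exported):
#
#   python generate.py \
#       --model-path-emb ./Llama-3.2-1B-EMB-16Bits.mlpackage \
#       --model-path-mf ./Llama-3.2-1B-4bits-MF.mlpackage \
#       --model-path-head ./Llama-3.2-1B-HEAD-6Bits.mlpackage \
#       -p "Write a haiku about the Neural Engine" \
#       --max-tokens 100 --min-p 0.3 --temp 0.8
#
# HF_TOKEN must be available (e.g. via a .env file read by load_dotenv) so the
# Llama 3.2 tokenizer can be fetched from the Hugging Face Hub.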