seba committed on
Commit a36f620
1 Parent(s): 4b6bd40

Upload coreml_example.py

Files changed (1)
  1. coreml_example.py +214 -0
coreml_example.py ADDED
@@ -0,0 +1,214 @@
import time
import math
import numpy as np
from argparse import ArgumentParser
from transformers import AutoTokenizer
from dotenv import load_dotenv
import os

load_dotenv()

# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", token=os.environ["HF_TOKEN"]
)

parser = ArgumentParser()
parser.add_argument("--model_path_emb", "--model-path-emb", required=True)
parser.add_argument("--model_path_mf", "--model-path-mf", required=True)
# parser.add_argument("--model_path_1", "--model-path-1", required=True)
# parser.add_argument("--model_path_40", "--model-path-40", required=True)
parser.add_argument("--model_path_head", "--model-path-head", required=True)
parser.add_argument("--prompt", "-p", required=True, type=str)
parser.add_argument("--max-tokens", "--max_tokens", type=int, default=100)
parser.add_argument("--min_p", "--min-p", type=float, default=0.3)
parser.add_argument("--temp", type=float, default=1.0)
args = parser.parse_args()
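
# Example invocation (the .mlpackage paths below are placeholders; point them
# at your own converted embedding, multifunction transformer, and LM-head
# bundles). HF_TOKEN must be set in the environment or a .env file so the
# gated Llama tokenizer can be downloaded:
#
#   python coreml_example.py \
#       --model-path-emb Llama-3.2-1B-EMB.mlpackage \
#       --model-path-mf Llama-3.2-1B-MF.mlpackage \
#       --model-path-head Llama-3.2-1B-HEAD.mlpackage \
#       -p "Tell me a joke" --max-tokens 100 --min-p 0.3 --temp 1.0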

import coremltools as ct

print("Loading models...")

cu = ct.ComputeUnit.CPU_AND_NE


def load_model(path, fname=None):
    # Accept either a compiled .mlmodelc bundle or an .mlpackage; `fname`
    # selects a function of a multifunction model (e.g. "length_1" or "length_40").
    if "mlmodelc" in path:
        return ct.models.CompiledMLModel(path, cu, fname)
    else:
        return ct.models.MLModel(path, cu, function_name=fname)


emb_model = load_model(args.model_path_emb)
model_1 = load_model(args.model_path_mf, "length_1")
model_40 = load_model(args.model_path_mf, "length_40")
model_head = load_model(args.model_path_head)

# if args.model_path.rstrip("/").endswith(".mlpackage"):
#     mf_model_1 = ct.models.MLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_1",
#     )
#     mf_model_64 = ct.models.MLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_64",
#     )
# else:
#     mf_model_1 = ct.models.CompiledMLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_1",
#     )
#     mf_model_64 = ct.models.CompiledMLModel(
#         args.model_path,
#         compute_units=ct.ComputeUnit.CPU_AND_NE,
#         function_name="length_64",
#     )

# mf_model_emb = ct.models.MLModel(
#     # args.model_path_emb,
#     "./Llama-3.2-1B-EMB-16Bits.mlpackage",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )
# mf_model_mf = ct.models.MLModel(
#     # args.model_path_1,
#     "./Llama-3.2-1B-4bits-MF.mlpackage/",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )
# mf_model_40 = ct.models.MLModel(
#     # args.model_path_40,
#     "./Llama-3.2-1B-4bits-CTX-40.mlpackage",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )
# head = ct.models.MLModel(
#     # args.model_path_head,
#     "./Llama-3.2-1B-HEAD-6Bits.mlpackage",
#     compute_units=ct.ComputeUnit.CPU_AND_NE,
#     # function_name="length_64",
# )


def save_compiled(model):
    # Copy the compiled model that Core ML produced on first load next to the
    # original .mlpackage, so later runs can load the cached .mlmodelc directly.
    from shutil import copytree

    compiled_model_path = model.get_compiled_model_path()
    copytree(
        compiled_model_path,
        model.package_path.replace(".mlpackage", ".mlmodelc"),
        dirs_exist_ok=True,
    )
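
# save_compiled is not called anywhere in this script; a plausible use (an
# assumption, not something shown here) is to run it once after loading an
# .mlpackage and then pass the resulting .mlmodelc paths on later runs to
# skip recompilation, e.g.:
#
#   save_compiled(emb_model)
#   save_compiled(model_head)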


def min_p_sample(logits, min_p, temp):
    # Min-p sampling on raw logits: after exp(logits - max) the top token has
    # weight 1.0, so `min_p` acts as a threshold relative to the most likely
    # token; everything below it is dropped before sampling via the cumulative sum.
    # logits = logits.astype(np.float16)
    logits = logits * (1 / temp)  # apply temperature before the stable-softmax shift
    max_ = np.max(logits, axis=1, keepdims=True)
    logits = logits - max_
    logits = np.exp(logits)
    logits[logits < min_p] = 0
    # logits = logits.astype(np.float32)
    logits = np.cumsum(logits, axis=1)
    sample = np.random.uniform(high=logits[:, -1:])
    sample = np.argmax(logits > sample, axis=1).astype(np.int32)
    return sample
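
# Rough sanity check of the filter above (illustrative only): with temp=1.0 and
# min_p=0.3, any token whose probability is below 30% of the top token's
# probability is zeroed, so
#   min_p_sample(np.array([[2.0, 1.9, -3.0]]), min_p=0.3, temp=1.0)
# can only ever return index 0 or 1, never 2.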


def build_causal_mask(seq_length, start, size, end):
    mask = np.full((1, 1, size, seq_length), np.array(-np.inf, dtype=np.float16))
    i, h, j, k = np.indices(mask.shape)
    # Unmask causal positions for real query rows (j < end); for padding rows
    # (j >= end) unmask only the first column so the softmax denominator is
    # never zero.
    mask[((k <= (j + start)) & (j < end)) | ((j >= end) & (k == 0))] = 0
    return mask
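
# For example, build_causal_mask(4, 0, 4, 2) gives (dropping the two leading
# singleton axes):
#   [[  0, -inf, -inf, -inf],
#    [  0,    0, -inf, -inf],
#    [  0, -inf, -inf, -inf],
#    [  0, -inf, -inf, -inf]]
# Rows 0-1 are ordinary causal rows; rows 2-3 are padding rows that attend
# only to position 0.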


if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Full 512 x 512 causal mask for the model's context window; slices of it are
# fed to the model below.
mask = build_causal_mask(512, 0, 512, 512)

max_length = 40
# length = len(tokenizer(args.prompt)["input_ids"])
prompt = [{"role": "user", "content": args.prompt}]
length = len(tokenizer.apply_chat_template(prompt, add_generation_prompt=True))
print("Prompt length:", length)
# Pad the prompt to a multiple of 40 tokens so it can be consumed by the
# fixed-shape "length_40" prefill function in whole chunks.
input_ids = tokenizer.apply_chat_template(
    prompt,
    return_tensors="np",
    padding=True,
    # max_length=max_length,
    return_dict=True,
    add_generation_prompt=True,
    tokenizer_kwargs={
        # "padding": True,
        "pad_to_multiple_of": max_length,
    },
)["input_ids"].astype(np.int32)
# input_ids = tokenizer(
#     args.prompt,
#     return_tensors="np",
#     padding="max_length",
#     max_length=max_length,
# )["input_ids"].astype(np.int32)
print("Prompt:\n", tokenizer.decode(input_ids[0]))
state = model_40.make_state()
start = time.time()
# Prefill: run the prompt through the "length_40" function chunk by chunk,
# updating the model state (the KV cache) in place.
for i in range(math.ceil(length / max_length)):
    input_embs = emb_model.predict(
        {"input_ids": input_ids[:, i * max_length : (i + 1) * max_length]}
    )["input_embeddings_channels_first"].astype(np.float16)
    pred = model_40.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i * max_length], dtype=np.int32),
            "mask": mask[:, :, i * max_length : (i + 1) * max_length],
            # "indices": np.array([0], dtype=np.int32),
            "indices": np.arange(i * max_length, (i + 1) * max_length, dtype=np.int32),
        },
        state,
    )
prompt_time = time.time() - start
# Project only the position of the last prompt token within the final chunk
# through the LM head and sample the first generated token.
pred = model_head.predict(
    {
        "hidden_states": pred["final_norm_rmsnorm"][
            ..., [length % max_length - 1]
        ].astype(np.float16)
    }
)
# input_ids = pred["logits"][..., length - 1].argmax(1, keepdims=True).astype(np.int32)
# logits = pred["logits"][..., [length - 1]]
logits = pred["concat_0"]
input_ids = min_p_sample(logits, args.min_p, args.temp)
print("Generated:")
print(tokenizer.decode(input_ids[0]), end="", flush=True)
start = time.time()
# Decode loop: generate one token at a time with the "length_1" function,
# reusing the state built up during prefill.
for i in range(args.max_tokens):
    input_embs = emb_model.predict({"input_ids": input_ids})[
        "input_embeddings_channels_first"
    ].astype(np.float16)
    pred = model_1.predict(
        {
            "input_ids": input_embs,
            "query_pos1": np.array([i + length], dtype=np.int32),
            "mask": mask[:, :, [i + length]],
            "indices": np.array([i + length], dtype=np.int32),
        },
        state,
    )
    pred = model_head.predict(
        {"hidden_states": pred["final_norm_rmsnorm"].astype(np.float16)}
    )
    # input_ids = min_p_sample(pred["logits"], args.min_p, args.temp)
    input_ids = min_p_sample(pred["concat_0"], args.min_p, args.temp)
    # input_ids = pred["logits"].argmax(1).astype(np.int32)
    print(tokenizer.decode(input_ids[0]), end="", flush=True)
print("", "=" * 10)
generation_time = time.time() - start

print(
    "Prompt:",
    length / prompt_time,
    "tokens-per-sec",
    f"({math.ceil(length / max_length) * max_length / prompt_time} including the processed padding)",
)
print("Generation:", args.max_tokens / generation_time, "tokens-per-sec")