tomer-deci commited on
Commit
24a66a7
1 Parent(s): 60965be

Upload benchmark_hf_model.py

Browse files
Files changed (1) hide show
  1. benchmark_hf_model.py +138 -0
benchmark_hf_model.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from argparse import ArgumentParser
3
+
4
+ import datasets
5
+ import torch
6
+ import transformers
7
+ from transformers import AutoModelForCausalLM, BatchEncoding
8
+
9
+ """
10
+ Usage examples (with the best batch sizes on A100-80GB-400W)
11
+ ============================================================
12
+ python -m benchmark_hf_model --model_name_or_path="Deci/DeciLM-7B" --batch_size=352
13
+ python -m benchmark_hf_model --model_name_or_path="mistralai/Mistral-7B-v0.1" --batch_size=192 --model_kwargs_json='{"use_flash_attention_2": true}'
14
+ python -m benchmark_hf_model --model_name_or_path="meta-llama/Llama-2-7b-hf" --batch_size=48 --model_kwargs_json='{"use_flash_attention_2": true}'
15
+ """
16
+
17
+
18
+ def parse_args():
19
+ parser = ArgumentParser()
20
+
21
+ parser.add_argument(
22
+ "--model_name_or_path",
23
+ type=str,
24
+ required=True,
25
+ )
26
+ parser.add_argument(
27
+ "--warmup_iters",
28
+ type=int,
29
+ default=10,
30
+ )
31
+ parser.add_argument(
32
+ "--iterations",
33
+ type=int,
34
+ default=5,
35
+ )
36
+ parser.add_argument(
37
+ "--batch_size",
38
+ type=int,
39
+ default=32,
40
+ )
41
+ parser.add_argument(
42
+ "--prompt_length",
43
+ type=int,
44
+ default=512,
45
+ )
46
+ parser.add_argument(
47
+ "--max_new_tokens",
48
+ type=int,
49
+ default=512,
50
+ )
51
+ parser.add_argument(
52
+ "--precision",
53
+ type=str,
54
+ default="bf16",
55
+ help="Model precision, from: fp32, fp16 or bf16",
56
+ )
57
+ parser.add_argument(
58
+ "--model_kwargs_json",
59
+ type=str,
60
+ default=None,
61
+ )
62
+ return parser.parse_args()
63
+
64
+
65
+ def main():
66
+ args = parse_args()
67
+ transformers.logging.set_verbosity_error()
68
+ datasets.logging.set_verbosity_error()
69
+
70
+ dict_precisions = {
71
+ "fp32": torch.float32,
72
+ "fp16": torch.float16,
73
+ "bf16": torch.bfloat16,
74
+ }
75
+ if args.precision not in dict_precisions:
76
+ raise ValueError(
77
+ f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16"
78
+ )
79
+ dtype = dict_precisions[args.precision]
80
+
81
+ model_kwargs = {}
82
+ if args.model_kwargs_json is not None:
83
+ model_kwargs = json.loads(args.model_kwargs_json)
84
+
85
+ print(f"loading model...")
86
+ model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True,
87
+ torch_dtype=dtype, **model_kwargs)
88
+ try:
89
+ print(model.model.layers[0].self_attn)
90
+ except:
91
+ print("couldn't print the model's attention module")
92
+
93
+ starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
94
+ model.cuda()
95
+ model.eval()
96
+
97
+ prompt = torch.ones(args.prompt_length, dtype=torch.long)
98
+ inputs = BatchEncoding({"input_ids": prompt.repeat(args.batch_size, 1)})
99
+ inputs = inputs.to(model.device)
100
+
101
+ # warmup
102
+ print(f"warming up for {args.warmup_iters} iterations...")
103
+ for _ in range(args.warmup_iters):
104
+ with torch.no_grad():
105
+ _ = model.generate(
106
+ **inputs,
107
+ max_new_tokens=1,
108
+ do_sample=False,
109
+ eos_token_id=-1234,
110
+ )
111
+ print('finished warmup')
112
+ torch.cuda.synchronize()
113
+
114
+ print(
115
+ f"prefill ({args.prompt_length} tokens{f' x {args.batch_size} batch' if args.batch_size > 1 else ''}) + generation ({args.max_new_tokens} tokens{f' x {args.batch_size} batch' if args.batch_size > 1 else ''}):")
116
+ tokens_generated = args.max_new_tokens * args.batch_size
117
+ prefill_and_generation = []
118
+ for gen_iter in range(args.iterations):
119
+ starter.record()
120
+ with torch.no_grad():
121
+ _ = model.generate(
122
+ **inputs,
123
+ max_new_tokens=args.max_new_tokens,
124
+ do_sample=False,
125
+ eos_token_id=-1234,
126
+ )
127
+ ender.record()
128
+ torch.cuda.synchronize()
129
+ t = starter.elapsed_time(ender) / 1000
130
+ prefill_and_generation.append(t)
131
+ print(f" iter {gen_iter + 1}: {t:.03f} sec total, {tokens_generated / t:.02f} generated tokens/sec")
132
+ aver = sum(prefill_and_generation) / len(prefill_and_generation)
133
+ print(f" average: {aver:.03f} sec total, {tokens_generated / aver:.02f} generated tokens/sec")
134
+ print(f"These results are obtained for model '{args.model_name_or_path}' with {args.batch_size=}.")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()