Batch inference seems to be done sequentially
When using a batch size larger than 1, the generation time increases almost linearly with the batch size. This is highly unexpected and not something I have seen with other transformer models; I would expect batched inputs to be processed in parallel, with little additional latency.
Script:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import transformers
import torch
import deepspeed
import time
from deepspeed.accelerator import get_accelerator

model = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16
)

batch_size = 1
input_prompt = [
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
] * batch_size
input_tokens = tokenizer.batch_encode_plus(
    input_prompt,
    return_tensors="pt",
)
token_num = input_tokens["input_ids"].size(-1)

# Move input tensors to the model's device
for t in input_tokens:
    if torch.is_tensor(input_tokens[t]):
        input_tokens[t] = input_tokens[t].to(model.device)
input_tokens.pop("token_type_ids")

# Warmup
print(f"Batch size {batch_size}")
sequences = model.generate(
    **input_tokens, min_length=512, max_length=512, do_sample=True
)

# Timed runs: two generations up to max_length=512, averaged
torch.cuda.synchronize()
st = time.monotonic()
for i in range(2):
    torch.cuda.synchronize()
    sequences = model.generate(
        **input_tokens, min_length=512, max_length=512, do_sample=True
    )
    torch.cuda.synchronize()
tt = time.monotonic() - st
print(f"Time taken {tt/2} time per token {tt/512/2}")
Results with different batch_size:
BS 1: Time taken 20.67650790150003 time per token 0.04038380449511725
BS 2: Time taken 32.592279224000094 time per token 0.06365679535937518
BS 4: Time taken 48.25992262649993 time per token 0.09425766137988267
BS 8: Time taken 86.17116434899981 time per token 0.16830305536914025
That's the time it takes to process the entire batch, so it increases with the number of samples. What you would expect to decrease is the time per sample, which indeed decreases when I look at your numbers.
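To make that concrete, here is a quick back-of-the-envelope check derived from the batch timings reported above (no new measurements, just the printed numbers): dividing each batch time by its batch size gives the per-sample time, which drops as the batch grows.

# Per-sample generation time, computed from the batch timings reported above
batch_times = {
    1: 20.67650790150003,
    2: 32.592279224000094,
    4: 48.25992262649993,
    8: 86.17116434899981,
}
for bs, t in batch_times.items():
    print(f"BS {bs}: {t / bs:.2f} s per sample")
# BS 1: 20.68 s per sample
# BS 2: 16.30 s per sample
# BS 4: 12.06 s per sample
# BS 8: 10.77 s per sample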
This is bullshit, g-ronimo.
be nice.
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
import time

model = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

for batch_size in [2, 4, 8]:
    input_prompt = [
        "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
    ] * batch_size
    input_tokens = tokenizer.batch_encode_plus(
        input_prompt,
        return_tensors="pt",
    ).to("cuda")
    input_tokens_cnt = sum([len(t) for t in input_tokens["input_ids"]])

    # Warmup
    sequences = model.generate(
        **input_tokens, min_length=512, max_length=512, do_sample=True
    )

    torch.cuda.synchronize()
    st = time.monotonic()
    generated_tokens_count = []
    num_trials = 2
    for i in range(num_trials):
        torch.cuda.synchronize()
        sequences = model.generate(
            **input_tokens,
            min_length=512,
            max_length=512,
            do_sample=True
        )
        torch.cuda.synchronize()
        sequences_tokens_cnt = sum([len(t) for t in sequences])
        generated_tokens_count.append(sequences_tokens_cnt - input_tokens_cnt)
    tt = time.monotonic() - st
    print(f"batch_size {batch_size}: Avg. time taken {tt/num_trials}, avg. time per token {tt/sum(generated_tokens_count)}")
Output:
batch_size 2: Avg. time taken 17.3977282285, avg. time per token 0.019287947038248338
batch_size 4: Avg. time taken 17.282605756000066, avg. time per token 0.009580158401330413
batch_size 8: Avg. time taken 18.542240016500045, avg. time per token 0.005139201778409103
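As a rough sanity check on those numbers, inverting the reported average time per token gives the aggregate throughput across the batch, which roughly doubles each time the batch size doubles:

# Aggregate throughput derived from the avg. time per token printed above
time_per_token = {
    2: 0.019287947038248338,
    4: 0.009580158401330413,
    8: 0.005139201778409103,
}
for bs, t in time_per_token.items():
    print(f"batch_size {bs}: ~{1 / t:.0f} tokens/s across the batch")
# batch_size 2: ~52 tokens/s across the batch
# batch_size 4: ~104 tokens/s across the batch
# batch_size 8: ~195 tokens/s across the batch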