Batch Decoding #3
by vody-am
Hi! Thanks for your charmingly compact model. For batch decoding, I did not see an example but figured out something roughly like:
```python
import torch

# `model`, `processor`, and `image` are assumed to be an already-loaded
# PaliGemma model, its processor, and a PIL image.

# Instruct the model to create a caption
prompt = "caption es"
prompts = [prompt] * 4
images = [image] * 4

model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

# Strip the prompt tokens, keeping only the newly generated ones
generation = generation[:, input_len:]
decoded_batch = processor.batch_decode(generation, skip_special_tokens=True)

for decoded in decoded_batch:
    print(decoded)
```
Is that correct? An example of batch inference would be helpful!
Thank you.
It looks good, should work!
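For anyone landing here with prompts of *different* lengths: the snippet above works because all four prompts are identical, so no padding is needed. Below is a minimal self-contained sketch of the general case. The checkpoint name and image paths are placeholders, and it assumes the processor forwards `padding` to its tokenizer; left padding keeps the `[:, input_len:]` slice valid for every row.

```python
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Placeholder checkpoint; any PaliGemma checkpoint should behave the same.
model_id = "google/paligemma-3b-mix-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# Prompts of different lengths in one batch.
prompts = ["caption es", "answer en what is in the image?"]
images = [Image.open("first.jpg"), Image.open("second.jpg")]

# Left padding keeps each prompt flush against its generated tokens, so
# slicing off `input_len` columns is correct for every row (assumption:
# the tokenizer honors `padding_side`).
processor.tokenizer.padding_side = "left"

model_inputs = processor(
    text=prompts, images=images, padding="longest", return_tensors="pt"
).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)

for decoded in processor.batch_decode(generation[:, input_len:], skip_special_tokens=True):
    print(decoded)
```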
merve changed discussion status to closed