Batch Decoding

#3 opened by vody-am

Hi! Thanks for your charmingly compact model. I didn't see an example of batch decoding, so I figured out something roughly like this:

# The snippet assumes `model`, `processor`, and `image` are already loaded
import torch

# Instruct the model to create a caption (here "es" asks for Spanish)
prompt = "caption es"
prompts = [prompt] * 4
images = [image] * 4
model_inputs = processor(text=prompts, images=images, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Drop the prompt tokens so only the generated captions remain
    generation = generation[:, input_len:]
    decoded_batch = processor.batch_decode(generation, skip_special_tokens=True)
    for decoded in decoded_batch:
        print(decoded)

Is that correct? An example of batch inference would be helpful!

Thank you.
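
P.S. For anyone copying the snippet verbatim: it assumes `model`, `processor`, and `image` already exist. A minimal setup sketch, assuming this is a PaliGemma checkpoint (as the "caption es" prompt format suggests); the model id below is illustrative, swap in the one from this repo:

# Minimal setup sketch; the model id is illustrative, not confirmed by this thread
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-mix-224"  # illustrative checkpoint id
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)
image = Image.open("example.jpg").convert("RGB")  # any local RGB image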

Google org (edited May 15)

It looks good, should work!
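
One variant worth sketching, since the prompts above are all identical: if the batched prompts differ in length, the tokenizer has to pad them so the tensors line up. A hedged sketch, assuming the processor forwards the standard `padding` kwarg to its tokenizer (the extra prompts below are illustrative PaliGemma task formats):

# Sketch for batching prompts of different lengths
prompts = ["caption es", "caption en", "answer en what is in the image?"]
images = [image] * len(prompts)
model_inputs = processor(
    text=prompts, images=images, padding="longest", return_tensors="pt"
).to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # input_len is the padded prompt length, so this strips prompt and padding
    decoded_batch = processor.batch_decode(
        generation[:, input_len:], skip_special_tokens=True
    )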

merve changed discussion status to closed
