import torch


@torch.no_grad()
def generate(
    vlm,
    samples,
    use_nucleus_sampling=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1.0,
    num_captions=1,
    temperature=1,
):
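    """Generate text for a batch of images or videos with a Q-Former + T5
    vision-language model.

    `samples` must contain "image" (a 4-D image batch, or a 5-D tensor for
    video) and may contain "prompt" (a string or a per-sample list) and
    "ocr_tokens". Returns a tuple of decoded strings and the corresponding
    sequence scores.
    """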
    # Use a per-sample prompt if provided; otherwise fall back to the
    # model's default prompt.
    if "prompt" in samples:
        prompt = samples["prompt"]
    else:
        prompt = vlm.prompt

    image = samples["image"]
    bs = image.size(0)

    # A single prompt string is broadcast to the whole batch.
    if isinstance(prompt, str):
        prompt = [prompt] * bs
    else:
        assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
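
    # If the prompt contains a "{}" placeholder, fill it with up to 30
    # OCR tokens per sample (comma-separated).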
if "ocr_tokens" in samples.keys() and "{}" in prompt[0]: |
|
prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)] |
|
|
|
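    # Expand the learned query tokens over the batch; when the Q-Former is
    # text-conditioned, also tokenize the prompt and build a joint attention
    # mask over queries + text.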
    query_tokens = vlm.query_tokens.expand(bs, -1, -1)
    if vlm.qformer_text_input:
        text_Qformer = vlm.tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            max_length=vlm.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
        Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)

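    # Video input (bs, C, T, H, W): encode each frame separately and
    # concatenate the per-frame query embeddings along the sequence dimension.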
    if image.dim() == 5:
        inputs_t5, atts_t5 = [], []
        for j in range(image.size(2)):
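            # Slice out frame j and encode it with the visual encoder.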
            this_frame = image[:, :, j, :, :]
            with vlm.maybe_autocast():
                frame_embeds = vlm.ln_vision(vlm.visual_encoder(this_frame))
            frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)

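            # Cross-attend the queries (and, if enabled, the instruction
            # text) to this frame's visual features.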
            if vlm.qformer_text_input:
                frame_query_output = vlm.Qformer.bert(
                    text_Qformer.input_ids,
                    attention_mask=Qformer_atts,
                    query_embeds=query_tokens,
                    encoder_hidden_states=frame_embeds,
                    encoder_attention_mask=frame_atts,
                    return_dict=True,
                )
            else:
                frame_query_output = vlm.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=frame_embeds,
                    encoder_attention_mask=frame_atts,
                    return_dict=True,
                )

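            # Keep only the query positions and project them into the T5
            # embedding space.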
            frame_inputs_t5 = vlm.t5_proj(frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
            frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
            inputs_t5.append(frame_inputs_t5)
            atts_t5.append(frame_atts_t5)
        inputs_t5 = torch.cat(inputs_t5, dim=1)
        atts_t5 = torch.cat(atts_t5, dim=1)
    else:
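        # Single-image input: one pass through the visual encoder and Q-Former.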
        with vlm.maybe_autocast():
            image_embeds = vlm.ln_vision(vlm.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if vlm.qformer_text_input:
            query_output = vlm.Qformer.bert(
                text_Qformer.input_ids,
                attention_mask=Qformer_atts,
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
        else:
            query_output = vlm.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

        inputs_t5 = vlm.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)

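    # Tokenize the prompt for T5; the encoder input is the projected query
    # embeddings followed by the prompt token embeddings.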
    input_tokens = vlm.t5_tokenizer(
        prompt,
        padding="longest",
        return_tensors="pt",
    ).to(image.device)

    encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

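    # Run the T5 model under bfloat16 autocast; prepend the visual query
    # embeddings to the prompt token embeddings.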
    with vlm.maybe_autocast(dtype=torch.bfloat16):
        inputs_embeds = vlm.t5_model.encoder.embed_tokens(input_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)

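        # Decode with beam search by default, or nucleus sampling when
        # use_nucleus_sampling is set.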
        outputs = vlm.t5_model.generate(
            return_dict_in_generate=True,
            output_scores=True,
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_atts,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )
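        # Decode the generated token ids back to text, skipping special tokens.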
        output_text = vlm.t5_tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )

    return output_text, outputs.sequences_scores
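

# Hedged usage sketch, not part of the original function: it assumes a
# LAVIS-style InstructBLIP checkpoint whose model object exposes the
# attributes used above (visual_encoder, ln_vision, Qformer, query_tokens,
# tokenizer, t5_tokenizer, t5_model, t5_proj, maybe_autocast, prompt,
# qformer_text_input, max_txt_len); the checkpoint names and image path
# below are assumptions, not taken from this file.
if __name__ == "__main__":
    from PIL import Image
    from lavis.models import load_model_and_preprocess

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Assumed checkpoint names; any InstructBLIP-T5 variant should work.
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2_t5_instruct", model_type="flant5xl", is_eval=True, device=device
    )
    raw_image = Image.open("example.jpg").convert("RGB")  # hypothetical path
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

    samples = {"image": image, "prompt": "Describe the image in detail."}
    texts, scores = generate(model, samples, num_beams=5, max_length=64)
    print(texts[0], scores[0].item())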