Add task_prefix_attention_mask argument to _merge_input_ids_with_image_features for better padding handling
Browse filesThis PR introduces a small change in the _merge_input_ids_with_image_features function by adding a task_prefix_attention_mask=None argument. This enhancement ensures that when doing batch processing with padding to the max length, the attention mask correctly ignores padding tokens.
Changes Made:
1. Added task_prefix_attention_mask=None argument to _merge_input_ids_with_image_features function.
2. Updated the function to incorporate the provided attention mask, allowing it to ignore padding tokens during batch processing.
Below is an example demonstrating the issue and the improvement:
```python
prompts =["prompt", "longer prompt", "much much longer prompt"]
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
images = [image] * len(prompts)
inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda", torch.float16)
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
image_features = model._encode_image(inputs.pixel_values)
print(inputs.input_ids)
# Output:
# tensor([[ 0, 12501, 3320, 2, 1, 1],
# [ 0, 3479, 254, 14302, 2, 1],
# [ 0, 28431, 203, 1181, 14302, 2]], device='cuda:0')
# Before change
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
# After change
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=inputs.attention_mask)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
```
- modeling_florence2.py +13 -8
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2643 |
return x
|
2644 |
|
2645 |
def _merge_input_ids_with_image_features(
|
2646 |
-
self, image_features, inputs_embeds
|
2647 |
):
|
2648 |
batch_size, image_token_length = image_features.size()[:-1]
|
2649 |
device = image_features.device
|
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2655 |
return image_features, image_attention_mask
|
2656 |
|
2657 |
task_prefix_embeds = inputs_embeds
|
2658 |
-
|
|
|
|
|
2659 |
|
2660 |
-
|
2661 |
-
|
2662 |
|
2663 |
# concat [image embeds, task prefix embeds]
|
2664 |
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
|
@@ -2719,6 +2721,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2719 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
2720 |
"A green car parked in front of a yellow building."
|
2721 |
```"""
|
|
|
2722 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
2723 |
output_hidden_states = (
|
2724 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
@@ -2734,8 +2737,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2734 |
if pixel_values is not None:
|
2735 |
# (batch_size, num_image_tokens, hidden_size)
|
2736 |
image_features = self._encode_image(pixel_values)
|
2737 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
2738 |
-
|
2739 |
if inputs_embeds is not None:
|
2740 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
2741 |
outputs = self.language_model(
|
@@ -2781,6 +2784,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2781 |
input_ids,
|
2782 |
inputs_embeds=None,
|
2783 |
pixel_values=None,
|
|
|
2784 |
**kwargs
|
2785 |
):
|
2786 |
|
@@ -2791,11 +2795,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2791 |
# 2. Merge text and images
|
2792 |
if pixel_values is not None:
|
2793 |
image_features = self._encode_image(pixel_values)
|
2794 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
|
2795 |
|
2796 |
return self.language_model.generate(
|
2797 |
input_ids=None,
|
2798 |
inputs_embeds=inputs_embeds,
|
|
|
2799 |
**kwargs
|
2800 |
)
|
2801 |
|
@@ -2844,4 +2849,4 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
2844 |
return self.language_model.shift_tokens_right(labels)
|
2845 |
|
2846 |
def _reorder_cache(self, *args, **kwargs):
|
2847 |
-
return self.language_model._reorder_cache(*args, **kwargs)
|
|
|
2643 |
return x
|
2644 |
|
2645 |
def _merge_input_ids_with_image_features(
|
2646 |
+
self, image_features, inputs_embeds, task_prefix_attention_mask=None
|
2647 |
):
|
2648 |
batch_size, image_token_length = image_features.size()[:-1]
|
2649 |
device = image_features.device
|
|
|
2655 |
return image_features, image_attention_mask
|
2656 |
|
2657 |
task_prefix_embeds = inputs_embeds
|
2658 |
+
|
2659 |
+
if task_prefix_attention_mask is None:
|
2660 |
+
task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
|
2661 |
|
2662 |
+
if len(task_prefix_attention_mask.shape) == 3:
|
2663 |
+
task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
|
2664 |
|
2665 |
# concat [image embeds, task prefix embeds]
|
2666 |
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
|
|
|
2721 |
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
2722 |
"A green car parked in front of a yellow building."
|
2723 |
```"""
|
2724 |
+
|
2725 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
2726 |
output_hidden_states = (
|
2727 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
2737 |
if pixel_values is not None:
|
2738 |
# (batch_size, num_image_tokens, hidden_size)
|
2739 |
image_features = self._encode_image(pixel_values)
|
2740 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
|
2741 |
+
|
2742 |
if inputs_embeds is not None:
|
2743 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
2744 |
outputs = self.language_model(
|
|
|
2784 |
input_ids,
|
2785 |
inputs_embeds=None,
|
2786 |
pixel_values=None,
|
2787 |
+
attention_mask=None,
|
2788 |
**kwargs
|
2789 |
):
|
2790 |
|
|
|
2795 |
# 2. Merge text and images
|
2796 |
if pixel_values is not None:
|
2797 |
image_features = self._encode_image(pixel_values)
|
2798 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
|
2799 |
|
2800 |
return self.language_model.generate(
|
2801 |
input_ids=None,
|
2802 |
inputs_embeds=inputs_embeds,
|
2803 |
+
attention_mask=attention_mask,
|
2804 |
**kwargs
|
2805 |
)
|
2806 |
|
|
|
2849 |
return self.language_model.shift_tokens_right(labels)
|
2850 |
|
2851 |
def _reorder_cache(self, *args, **kwargs):
|
2852 |
+
return self.language_model._reorder_cache(*args, **kwargs)
|