pawlowskipawel committed
Commit 579b82b
1 Parent(s): ee1f1f1

Add task_prefix_attention_mask argument to _merge_input_ids_with_image_features for better padding handling


This PR introduces a small change to the `_merge_input_ids_with_image_features` function by adding a `task_prefix_attention_mask=None` argument. This ensures that, when doing batch processing with padding to the maximum length, the attention mask correctly ignores padding tokens.

Changes Made:
1. Added a `task_prefix_attention_mask=None` argument to `_merge_input_ids_with_image_features`.
2. Updated the function to use the provided attention mask so that padding tokens are ignored during batch processing.

Below is an example demonstrating the issue and the improvement:
```python
import requests
import torch
from PIL import Image

# `model` and `processor` are assumed to be a Florence-2 model and its processor,
# loaded beforehand (e.g. via AutoModelForCausalLM / AutoProcessor with
# trust_remote_code=True).

prompts = ["prompt", "longer prompt", "much much longer prompt"]

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"

image = Image.open(requests.get(url, stream=True).raw)
images = [image] * len(prompts)

inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda", torch.float16)

inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
image_features = model._encode_image(inputs.pixel_values)

print(inputs.input_ids)
# Output:
# tensor([[    0, 12501,  3320,     2,     1,     1],
#         [    0,  3479,   254, 14302,     2,     1],
#         [    0, 28431,   203,  1181, 14302,     2]], device='cuda:0')

# Before the change: the merged mask is all ones, so padding tokens are attended to.
merged_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')

# After the change: passing the processor's mask zeroes out the padding positions.
merged_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=inputs.attention_mask)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
```
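For context, here is a minimal sketch of the masking logic the patched function follows once the new argument is wired in. The `image_attention_mask` tensor and the final concatenation of the two masks are assumptions inferred from the surrounding code visible in the diff below; only the `task_prefix_attention_mask` handling is part of this change.

```python
import torch

def merge_masks_sketch(image_features, inputs_embeds, task_prefix_attention_mask=None):
    # Image tokens are always attended to, so their mask is all ones.
    batch_size, image_token_length = image_features.size()[:-1]
    device = image_features.device
    image_attention_mask = torch.ones(batch_size, image_token_length, device=device)

    task_prefix_embeds = inputs_embeds

    # New behaviour: only fall back to an all-ones mask when no mask is given,
    # so padded prompts keep their zeros.
    if task_prefix_attention_mask is None:
        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
    if len(task_prefix_attention_mask.shape) == 3:
        task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

    # Concatenate [image, task prefix] for both the embeddings and the mask.
    merged_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
    merged_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
    return merged_embeds, merged_mask
```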

Files changed (1):
  modeling_florence2.py  +13 -8
modeling_florence2.py CHANGED
```diff
@@ -2643,7 +2643,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2655,10 +2655,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
-        if len(task_prefix_attention_mask.shape) == 3:
-            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
@@ -2719,6 +2721,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "A green car parked in front of a yellow building."
         ```"""
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -2734,8 +2737,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         if pixel_values is not None:
             # (batch_size, num_image_tokens, hidden_size)
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
-
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
+
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
         outputs = self.language_model(
@@ -2781,6 +2784,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         input_ids,
         inputs_embeds=None,
         pixel_values=None,
+        attention_mask=None,
         **kwargs
     ):
 
@@ -2791,11 +2795,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         # 2. Merge text and images
         if pixel_values is not None:
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         return self.language_model.generate(
             input_ids=None,
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             **kwargs
         )
 
@@ -2844,4 +2849,4 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return self.language_model.shift_tokens_right(labels)
 
     def _reorder_cache(self, *args, **kwargs):
-        return self.language_model._reorder_cache(*args, **kwargs)
+        return self.language_model._reorder_cache(*args, **kwargs)
```
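
With this change wired through `generate`, batched inference with padded prompts can simply forward the processor's attention mask. A hedged usage sketch, assuming the same `model`, `processor`, and `inputs` as in the example above (the generation settings shown are illustrative and not part of this PR):

```python
# Illustrative only: generation settings here are examples, not part of this PR.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    attention_mask=inputs["attention_mask"],  # now forwarded to the image/text merge
    max_new_tokens=128,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```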