jadechoghari committed on
Commit 97ceae8
1 Parent(s): 753f935

Create modeling.py

Files changed (1)
  1. modeling.py +168 -0
modeling.py ADDED
@@ -0,0 +1,168 @@
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

try:
    # Gemma classes require transformers >= 4.38.0.
    from transformers import AutoConfig, AutoModelForCausalLM, \
        GemmaConfig, GemmaModel, GemmaForCausalLM
except ImportError:
    print("Gemma classes could not be imported. Update Transformers to 4.38.0 or later.")
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from transformers.generation.utils import logging

from ..ferret_arch import FerretMetaModel, FerretMetaForCausalLM

logger = logging.get_logger(__name__)


class FerretGemmaConfig(GemmaConfig):
    model_type = "ferret_gemma"


class FerretGemmaModel(FerretMetaModel, GemmaModel):
    config_class = FerretGemmaConfig

    def __init__(self, config: GemmaConfig):
        super(FerretGemmaModel, self).__init__(config)


class FerretGemmaForCausalLM(GemmaForCausalLM, FerretMetaForCausalLM):
    config_class = FerretGemmaConfig

    def __init__(self, config):
        super(GemmaForCausalLM, self).__init__(config)
        self.model = FerretGemmaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        region_masks: Optional[List[torch.Tensor]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            # Merge image features and region masks into the token embeddings
            # before running the Gemma language-model forward pass.
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes=image_sizes,
                region_masks=region_masks,
            )

        forward_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            cache_position=cache_position,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return forward_output

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        region_masks: Optional[List[torch.Tensor]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            # Build multimodal input embeddings from the image(s) and region masks;
            # text-only prompts fall back to plain token embeddings below.
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes,
                region_masks=region_masks,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        # Carry the image inputs through each decoding step of generate().
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs


# Register the custom config and model so the Auto* classes can resolve
# checkpoints whose config declares model_type "ferret_gemma".
AutoConfig.register("ferret_gemma", FerretGemmaConfig)
AutoModelForCausalLM.register(FerretGemmaConfig, FerretGemmaForCausalLM)
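
Because the final two lines register FerretGemmaConfig and FerretGemmaForCausalLM with the Auto* factories, a checkpoint whose config sets model_type "ferret_gemma" can be loaded through the standard AutoModelForCausalLM path. Below is a minimal usage sketch, not part of this commit: the checkpoint id "jadechoghari/ferret-gemma", the 336x336 image size, and the dummy image/region tensors are assumptions for illustration, and the real pipeline would run an image processor rather than random tensors.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# trust_remote_code lets the hub repo ship this modeling.py alongside the weights,
# which triggers the AutoConfig/AutoModelForCausalLM registration above.
model = AutoModelForCausalLM.from_pretrained(
    "jadechoghari/ferret-gemma",           # hypothetical checkpoint id
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("jadechoghari/ferret-gemma")

input_ids = tokenizer("Describe the highlighted region.", return_tensors="pt").input_ids
images = torch.randn(1, 3, 336, 336, dtype=torch.float16)   # placeholder preprocessed image
region_masks = [torch.ones(336, 336)]                        # placeholder binary region mask

output_ids = model.generate(
    inputs=input_ids,
    images=images,
    image_sizes=[[336, 336]],
    region_masks=region_masks,
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that the extra keyword arguments (images, image_sizes, region_masks) are consumed by the overridden generate() and prepare_inputs_for_generation() shown above; everything else is forwarded to the stock Transformers generation loop.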