Galuh Sahid commited on
Commit
653217a
1 Parent(s): 588c4c1
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Indonesian Image Captioning
3
- emoji: 💩
4
  colorFrom: yellow
5
  colorTo: yellow
6
  sdk: streamlit
 
1
  ---
2
  title: Indonesian Image Captioning
3
+ emoji: 🖼️
4
  colorFrom: yellow
5
  colorTo: yellow
6
  sdk: streamlit
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from PIL import Image
3
+ from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor, Compose
4
+ from torchvision.transforms.functional import InterpolationMode
5
+ import torch
6
+ import numpy as np
7
+ from transformers import MarianTokenizer
8
+ from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration
9
+ import logging
10
+ import streamlit as st
11
+ from mtranslate import translate
12
+
13
+ class CaptionGenerator:
14
+ def __init__(self):
15
+ self.tokenizer = None
16
+ self.clip_marian_model = None
17
+ self.marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
18
+ self.clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
19
+
20
+ self.config = None
21
+ self.image_size = None
22
+ self.custom_transforms = None
23
+
24
+ def load(self):
25
+ logging.info("Loading tokenizer...")
26
+ marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
27
+ self.tokenizer = MarianTokenizer.from_pretrained(self.marian_model_name)
28
+ logging.info("Tokenizer loaded.")
29
+
30
+ logging.info("Loading model...")
31
+ self.model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(self.clip_marian_model_name)
32
+ logging.info("Model loaded.")
33
+
34
+ self.config = self.model.config
35
+ self.image_size = self.config.clip_vision_config.image_size
36
+
37
+ self.custom_transforms = torch.nn.Sequential(
38
+ Resize([self.image_size], interpolation=InterpolationMode.BICUBIC),
39
+ CenterCrop(self.image_size),
40
+ ConvertImageDtype(torch.float),
41
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
42
+ )
43
+
44
+ def process_image(self, file):
45
+ logging.info("Loading image...")
46
+ image_data = file.read()
47
+ input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
48
+ loader = Compose([ToTensor()])
49
+ image = loader(input_image)
50
+ image = self.custom_transforms(image)
51
+ pixel_values = torch.stack([image]).permute(0, 2, 3, 1).numpy()
52
+ logging.info("Image loaded.")
53
+
54
+ return pixel_values
55
+
56
+ def generate_step(self, pixel_values, max_len, num_beams):
57
+ gen_kwargs = {"max_length": max_len , "num_beams": num_beams}
58
+
59
+ logging.info("Generating caption...")
60
+ output_ids = self.model.generate(pixel_values, **gen_kwargs)
61
+ token_ids = np.array(output_ids.sequences)[0]
62
+ caption = self.tokenizer.decode(token_ids)
63
+ logging.info("Caption generated.")
64
+
65
+ return caption
66
+
67
+ def get_caption(self, file, max_len, num_beams):
68
+ pixel_values = self.process_image(file)
69
+
70
+ generated_ids = self.generate_step(pixel_values, max_len, num_beams)
71
+ return generated_ids
72
+
73
+ @st.cache(allow_output_mutation=True)
74
+ def load_caption_generator():
75
+ generator = CaptionGenerator()
76
+ generator.load()
77
+ return generator
78
+
79
+ def main():
80
+ st.set_page_config(page_title="Indonesian Image Captioning Demo", page_icon="🖼️")
81
+ generator = load_caption_generator()
82
+
83
+ st.title("Indonesian Image Captioning Demo")
84
+
85
+ st.markdown(
86
+ """Indonesian image captioning demo, trained on [CLIP](https://huggingface.co/transformers/model_doc/clip.html) and [Marian](https://huggingface.co/transformers/model_doc/marian.html). Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
87
+ """
88
+ )
89
+
90
+ st.sidebar.subheader("Configurable parameters")
91
+
92
+ max_len = st.sidebar.number_input(
93
+ "Maximum length",
94
+ value=8,
95
+ help="The maximum length of the sequence (caption) to be generated."
96
+ )
97
+
98
+ num_beams = st.sidebar.number_input(
99
+ "Number of beams",
100
+ value=4,
101
+ help="Number of beams for beam search. 1 means no beam search."
102
+ )
103
+
104
+ input_image = st.file_uploader("Insert image")
105
+ if st.button("Run"):
106
+ with st.spinner(text="Getting results..."):
107
+ if input_image:
108
+ caption = generator.get_caption(file=input_image, max_len=max_len, num_beams=num_beams)
109
+ st.subheader("Result")
110
+ st.write(caption.replace("<pad>", ""))
111
+ st.text("English translation")
112
+ st.write(translate(caption, "en", "id").replace("<pad>", ""))
113
+ else:
114
+ st.write("Please upload an image.")
115
+
116
+ if __name__ == '__main__':
117
+ main()
flax_clip_vision_marian/__pycache__/configuration_clip_vision_marian.cpython-37.pyc ADDED
Binary file (1.65 kB). View file
 
flax_clip_vision_marian/__pycache__/configuration_clip_vision_marian.cpython-38.pyc ADDED
Binary file (1.74 kB). View file
 
flax_clip_vision_marian/__pycache__/generation_clip_vision_utils.cpython-37.pyc ADDED
Binary file (21.4 kB). View file
 
flax_clip_vision_marian/__pycache__/generation_clip_vision_utils.cpython-38.pyc ADDED
Binary file (21.9 kB). View file
 
flax_clip_vision_marian/__pycache__/modeling_clip_vision_marian.cpython-37.pyc ADDED
Binary file (15.2 kB). View file
 
flax_clip_vision_marian/__pycache__/modeling_clip_vision_marian.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
flax_clip_vision_marian/__pycache__/modeling_clip_vision_utils.cpython-37.pyc ADDED
Binary file (16.4 kB). View file
 
flax_clip_vision_marian/__pycache__/modeling_clip_vision_utils.cpython-38.pyc ADDED
Binary file (16.7 kB). View file
 
flax_clip_vision_marian/configuration_clip_vision_marian.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from transformers import CLIPVisionConfig, MarianConfig
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class CLIPVisionMarianConfig(PretrainedConfig):
11
+
12
+ model_type = "clip-vision-marian"
13
+ is_composition = True
14
+
15
+ def __init__(self, **kwargs):
16
+ super().__init__(**kwargs)
17
+
18
+ if "marian_config" not in kwargs:
19
+ raise ValueError("`marian_config_dict` can not be `None`.")
20
+
21
+ if "clip_vision_config" not in kwargs:
22
+ raise ValueError("`clip_vision_config_dict` can not be `None`.")
23
+
24
+ marian_config = kwargs.pop("marian_config")
25
+ clip_vision_config = kwargs.pop("clip_vision_config")
26
+
27
+ self.marian_config = MarianConfig(**marian_config)
28
+
29
+ self.clip_vision_config = CLIPVisionConfig(**clip_vision_config)
30
+
31
+ self.is_encoder_decoder = True
32
+
33
+ @classmethod
34
+ def from_clip_vision_marian_configs(
35
+ cls,
36
+ clip_vision_config: PretrainedConfig,
37
+ marian_config: PretrainedConfig,
38
+ **kwargs
39
+ ):
40
+ return cls(
41
+ clip_vision_config=clip_vision_config.to_dict(),
42
+ marian_config=marian_config.to_dict(),
43
+ **kwargs
44
+ )
45
+
46
+ def to_dict(self):
47
+ output = copy.deepcopy(self.__dict__)
48
+ output["clip_vision_config"] = self.clip_vision_config.to_dict()
49
+ output["marian_config"] = self.marian_config.to_dict()
50
+ output["model_type"] = self.__class__.model_type
51
+ return output
flax_clip_vision_marian/generation_clip_vision_utils.py ADDED
@@ -0,0 +1,990 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ import flax
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jaxlib.xla_extension as jax_xla
7
+ import numpy as np
8
+ from jax import lax
9
+ from transformers.file_utils import ModelOutput
10
+ from transformers.generation_flax_logits_process import (
11
+ FlaxForcedBOSTokenLogitsProcessor,
12
+ FlaxForcedEOSTokenLogitsProcessor,
13
+ FlaxLogitsProcessorList,
14
+ FlaxMinLengthLogitsProcessor,
15
+ FlaxTemperatureLogitsWarper,
16
+ FlaxTopKLogitsWarper,
17
+ FlaxTopPLogitsWarper,
18
+ )
19
+ from transformers.utils import logging
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ @flax.struct.dataclass
25
+ class FlaxGreedySearchOutput(ModelOutput):
26
+ """
27
+ Flax Base class for outputs of decoder-only generation models using greedy search.
28
+ Args:
29
+ sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
30
+ The generated sequences.
31
+ """
32
+
33
+ sequences: jax_xla.DeviceArray = None
34
+
35
+
36
+ @flax.struct.dataclass
37
+ class FlaxSampleOutput(ModelOutput):
38
+ """
39
+ Flax Base class for outputs of decoder-only generation models using sampling.
40
+ Args:
41
+ sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
42
+ The generated sequences.
43
+ """
44
+
45
+ sequences: jax_xla.DeviceArray = None
46
+
47
+
48
+ @flax.struct.dataclass
49
+ class FlaxBeamSearchOutput(ModelOutput):
50
+ """
51
+ Flax Base class for outputs of decoder-only generation models using greedy search.
52
+ Args:
53
+ sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
54
+ The generated sequences.
55
+ scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
56
+ The scores (log probabilites) of the generated sequences.
57
+ """
58
+
59
+ sequences: jax_xla.DeviceArray = None
60
+ scores: jax_xla.DeviceArray = None
61
+
62
+
63
+ @flax.struct.dataclass
64
+ class GreedyState:
65
+ cur_len: jax_xla.DeviceArray
66
+ sequences: jax_xla.DeviceArray
67
+ running_token: jax_xla.DeviceArray
68
+ is_sent_finished: jax_xla.DeviceArray
69
+ model_kwargs: Dict[str, jax_xla.DeviceArray]
70
+
71
+
72
+ @flax.struct.dataclass
73
+ class SampleState:
74
+ cur_len: jax_xla.DeviceArray
75
+ sequences: jax_xla.DeviceArray
76
+ running_token: jax_xla.DeviceArray
77
+ is_sent_finished: jax_xla.DeviceArray
78
+ prng_key: jax_xla.DeviceArray
79
+ model_kwargs: Dict[str, jax_xla.DeviceArray]
80
+
81
+
82
+ @flax.struct.dataclass
83
+ class BeamSearchState:
84
+ cur_len: jax_xla.DeviceArray
85
+ running_sequences: jax_xla.DeviceArray
86
+ running_scores: jax_xla.DeviceArray
87
+ sequences: jax_xla.DeviceArray
88
+ scores: jax_xla.DeviceArray
89
+ is_sent_finished: jax_xla.DeviceArray
90
+ model_kwargs: Dict[str, jax_xla.DeviceArray]
91
+
92
+
93
+ class FlaxCLIPVisionMarianGenerationMixin:
94
+ """
95
+ A class containing all of the functions supporting generation, to be used as a mixin in
96
+ :class:`~transformers.FlaxPreTrainedModel`.
97
+ """
98
+
99
+ @staticmethod
100
+ def _run_loop_in_debug(cond_fn, body_fn, init_state):
101
+ """
102
+ Run generation in untraced mode. This should only be used for debugging purposes.
103
+ """
104
+ state = init_state
105
+ while cond_fn(state):
106
+ state = body_fn(state)
107
+ return state
108
+
109
+ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
110
+ encoder_kwargs = {
111
+ argument: value
112
+ for argument, value in model_kwargs.items()
113
+ if not (
114
+ argument.startswith("decoder_") or argument.startswith("cross_attn")
115
+ )
116
+ }
117
+ model_kwargs["encoder_outputs"] = self.encode(
118
+ input_ids, return_dict=True, **encoder_kwargs
119
+ )
120
+ return model_kwargs
121
+
122
+ @staticmethod
123
+ def _expand_to_num_beams(tensor, num_beams):
124
+ return jnp.broadcast_to(
125
+ tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]
126
+ )
127
+
128
+ def generate(
129
+ self,
130
+ input_ids: jax_xla.DeviceArray,
131
+ max_length: Optional[int] = None,
132
+ pad_token_id: Optional[int] = None,
133
+ bos_token_id: Optional[int] = None,
134
+ eos_token_id: Optional[int] = None,
135
+ decoder_start_token_id: Optional[int] = None,
136
+ do_sample: Optional[bool] = None,
137
+ prng_key: Optional[jax_xla.DeviceArray] = None,
138
+ top_k: Optional[int] = None,
139
+ top_p: Optional[float] = None,
140
+ temperature: Optional[float] = None,
141
+ num_beams: Optional[int] = None,
142
+ no_repeat_ngram_size: Optional[int] = None,
143
+ min_length: Optional[int] = None,
144
+ forced_bos_token_id: Optional[int] = None,
145
+ forced_eos_token_id: Optional[int] = None,
146
+ length_penalty: Optional[float] = None,
147
+ early_stopping: Optional[bool] = None,
148
+ trace: bool = True,
149
+ params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
150
+ **model_kwargs,
151
+ ):
152
+ r"""
153
+ Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
154
+ and, multinomial sampling.
155
+ Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same
156
+ name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
157
+ default values of those config.
158
+ Most of these parameters are explained in more detail in `this blog post
159
+ <https://huggingface.co/blog/how-to-generate>`__.
160
+ Parameters:
161
+ input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
162
+ The sequence used as a prompt for the generation.
163
+ max_length (:obj:`int`, `optional`, defaults to 20):
164
+ The maximum length of the sequence to be generated.
165
+ do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
166
+ Whether or not to use sampling ; use greedy decoding otherwise.
167
+ temperature (:obj:`float`, `optional`, defaults to 1.0):
168
+ The value used to module the next token probabilities.
169
+ top_k (:obj:`int`, `optional`, defaults to 50):
170
+ The number of highest probability vocabulary tokens to keep for top-k-filtering.
171
+ top_p (:obj:`float`, `optional`, defaults to 1.0):
172
+ If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
173
+ higher are kept for generation.
174
+ pad_token_id (:obj:`int`, `optional`):
175
+ The id of the `padding` token.
176
+ bos_token_id (:obj:`int`, `optional`):
177
+ The id of the `beginning-of-sequence` token.
178
+ eos_token_id (:obj:`int`, `optional`):
179
+ The id of the `end-of-sequence` token.
180
+ num_beams (:obj:`int`, `optional`, defaults to 1):
181
+ Number of beams for beam search. 1 means no beam search.
182
+ decoder_start_token_id (:obj:`int`, `optional`):
183
+ If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
184
+ trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
185
+ Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
186
+ a considerably slower runtime.
187
+ params (:obj:`Dict[str, jax_xla.DeviceArray]`, `optional`):
188
+ Optionally the model parameters can be passed. Can be useful for parallelized generation.
189
+ model_kwargs:
190
+ Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
191
+ Return:
192
+ :class:`~transformers.file_utils.ModelOutput`.
193
+ Examples::
194
+ >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
195
+ >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
196
+ >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
197
+ >>> input_context = "The dog"
198
+ >>> # encode input context
199
+ >>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
200
+ >>> # generate candidates using sampling
201
+ >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
202
+ >>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
203
+ """
204
+ # set init values
205
+ max_length = (
206
+ max_length
207
+ if max_length is not None
208
+ else self.config.marian_config.max_length
209
+ )
210
+ bos_token_id = (
211
+ bos_token_id
212
+ if bos_token_id is not None
213
+ else self.config.marian_config.bos_token_id
214
+ )
215
+ pad_token_id = (
216
+ pad_token_id
217
+ if pad_token_id is not None
218
+ else self.config.marian_config.pad_token_id
219
+ )
220
+ eos_token_id = (
221
+ eos_token_id
222
+ if eos_token_id is not None
223
+ else self.config.marian_config.eos_token_id
224
+ )
225
+ decoder_start_token_id = (
226
+ decoder_start_token_id
227
+ if decoder_start_token_id
228
+ else self.config.marian_config.decoder_start_token_id
229
+ )
230
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
231
+
232
+ if decoder_start_token_id is None and self.config.is_encoder_decoder:
233
+ raise ValueError(
234
+ "`decoder_start_token_id` has to be defined for encoder-decoder generation."
235
+ )
236
+
237
+ if self.config.is_encoder_decoder:
238
+ # add encoder_outputs to model_kwargs
239
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
240
+ input_ids, model_kwargs
241
+ )
242
+ # prepare decoder_input_ids for generation
243
+ input_ids = (
244
+ jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
245
+ )
246
+
247
+ do_sample = (
248
+ do_sample if do_sample is not None else self.config.marian_config.do_sample
249
+ )
250
+ num_beams = (
251
+ num_beams if num_beams is not None else self.config.marian_config.num_beams
252
+ )
253
+
254
+ if not do_sample and num_beams == 1:
255
+ logits_processor = self._get_logits_processor(
256
+ no_repeat_ngram_size,
257
+ min_length,
258
+ max_length,
259
+ eos_token_id,
260
+ forced_bos_token_id,
261
+ forced_eos_token_id,
262
+ )
263
+ return self._greedy_search(
264
+ input_ids,
265
+ max_length,
266
+ pad_token_id,
267
+ eos_token_id,
268
+ logits_processor=logits_processor,
269
+ trace=trace,
270
+ params=params,
271
+ model_kwargs=model_kwargs,
272
+ )
273
+ elif do_sample and num_beams == 1:
274
+ logits_warper = self._get_logits_warper(
275
+ top_k=top_k, top_p=top_p, temperature=temperature
276
+ )
277
+ logits_processor = self._get_logits_processor(
278
+ no_repeat_ngram_size,
279
+ min_length,
280
+ max_length,
281
+ eos_token_id,
282
+ forced_bos_token_id,
283
+ forced_eos_token_id,
284
+ )
285
+ return self._sample(
286
+ input_ids,
287
+ max_length,
288
+ pad_token_id,
289
+ eos_token_id,
290
+ prng_key,
291
+ logits_warper=logits_warper,
292
+ logits_processor=logits_processor,
293
+ trace=trace,
294
+ params=params,
295
+ model_kwargs=model_kwargs,
296
+ )
297
+ elif not do_sample and num_beams > 1:
298
+ # broadcast input_ids & encoder_outputs
299
+ input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
300
+
301
+ if "encoder_outputs" in model_kwargs:
302
+ model_kwargs["encoder_outputs"][
303
+ "last_hidden_state"
304
+ ] = self._expand_to_num_beams(
305
+ model_kwargs["encoder_outputs"]["last_hidden_state"],
306
+ num_beams=num_beams,
307
+ )
308
+
309
+ if "attention_mask" in model_kwargs:
310
+ model_kwargs["attention_mask"] = self._expand_to_num_beams(
311
+ model_kwargs["attention_mask"], num_beams=num_beams
312
+ )
313
+
314
+ logits_processor = self._get_logits_processor(
315
+ no_repeat_ngram_size,
316
+ min_length,
317
+ max_length,
318
+ eos_token_id,
319
+ forced_bos_token_id,
320
+ forced_eos_token_id,
321
+ )
322
+
323
+ return self._beam_search(
324
+ input_ids,
325
+ max_length,
326
+ pad_token_id,
327
+ eos_token_id,
328
+ length_penalty=length_penalty,
329
+ early_stopping=early_stopping,
330
+ logits_processor=logits_processor,
331
+ trace=trace,
332
+ params=params,
333
+ model_kwargs=model_kwargs,
334
+ )
335
+ else:
336
+ raise NotImplementedError("`Beam sampling is currently not implemented.")
337
+
338
+ def _get_logits_warper(
339
+ self, top_k: int = None, top_p: float = None, temperature: float = None
340
+ ) -> FlaxLogitsProcessorList:
341
+ """
342
+ This class returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
343
+ :obj:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling.
344
+ """
345
+
346
+ # init warp parameters
347
+ top_k = top_k if top_k is not None else self.config.marian_config.top_k
348
+ top_p = top_p if top_p is not None else self.config.marian_config.top_p
349
+ temperature = (
350
+ temperature
351
+ if temperature is not None
352
+ else self.config.marian_config.temperature
353
+ )
354
+ # instantiate warpers list
355
+ warpers = FlaxLogitsProcessorList()
356
+
357
+ # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
358
+ # all samplers can be found in `generation_utils_samplers.py`
359
+ if temperature is not None and temperature != 1.0:
360
+ warpers.append(FlaxTemperatureLogitsWarper(temperature))
361
+ if top_k is not None and top_k != 0:
362
+ warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
363
+ if top_p is not None and top_p < 1.0:
364
+ warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
365
+
366
+ return warpers
367
+
368
+ def _get_logits_processor(
369
+ self,
370
+ no_repeat_ngram_size: int,
371
+ min_length: int,
372
+ max_length: int,
373
+ eos_token_id: int,
374
+ forced_bos_token_id: int,
375
+ forced_eos_token_id: int,
376
+ ) -> FlaxLogitsProcessorList:
377
+ """
378
+ This class returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
379
+ :obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language model head.
380
+ """
381
+ processors = FlaxLogitsProcessorList()
382
+
383
+ # init warp parameters
384
+ no_repeat_ngram_size = (
385
+ no_repeat_ngram_size
386
+ if no_repeat_ngram_size is not None
387
+ else self.config.marian_config.no_repeat_ngram_size
388
+ )
389
+ min_length = (
390
+ min_length
391
+ if min_length is not None
392
+ else self.config.marian_config.min_length
393
+ )
394
+ eos_token_id = (
395
+ eos_token_id
396
+ if eos_token_id is not None
397
+ else self.config.marian_config.eos_token_id
398
+ )
399
+ forced_bos_token_id = (
400
+ forced_bos_token_id
401
+ if forced_bos_token_id is not None
402
+ else self.config.marian_config.forced_bos_token_id
403
+ )
404
+ forced_eos_token_id = (
405
+ forced_eos_token_id
406
+ if forced_eos_token_id is not None
407
+ else self.config.marian_config.forced_eos_token_id
408
+ )
409
+
410
+ # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
411
+ # all samplers can be found in `generation_utils_samplers.py`
412
+ if min_length is not None and eos_token_id is not None and min_length > -1:
413
+ processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
414
+ if forced_bos_token_id is not None:
415
+ processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
416
+ if forced_eos_token_id is not None:
417
+ processors.append(
418
+ FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)
419
+ )
420
+ return processors
421
+
422
+ def _greedy_search(
423
+ self,
424
+ input_ids: None,
425
+ max_length: Optional[int] = None,
426
+ pad_token_id: Optional[int] = None,
427
+ eos_token_id: Optional[int] = None,
428
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
429
+ trace: bool = True,
430
+ params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
431
+ model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
432
+ ):
433
+ # init values
434
+ max_length = (
435
+ max_length
436
+ if max_length is not None
437
+ else self.config.marian_config.max_length
438
+ )
439
+ pad_token_id = (
440
+ pad_token_id
441
+ if pad_token_id is not None
442
+ else self.config.marian_config.pad_token_id
443
+ )
444
+ eos_token_id = (
445
+ eos_token_id
446
+ if eos_token_id is not None
447
+ else self.config.marian_config.eos_token_id
448
+ )
449
+
450
+ batch_size, cur_len = input_ids.shape
451
+
452
+ eos_token_id = jnp.array(eos_token_id)
453
+ pad_token_id = jnp.array(pad_token_id)
454
+ cur_len = jnp.array(cur_len)
455
+
456
+ # per batch-item holding current token in loop.
457
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
458
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
459
+
460
+ # per batch-item state bit indicating if sentence has finished.
461
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
462
+
463
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
464
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
465
+ model = self.decode if self.config.is_encoder_decoder else self
466
+ # initialize model specific kwargs
467
+ model_kwargs = self.prepare_inputs_for_generation(
468
+ input_ids, max_length, **model_kwargs
469
+ )
470
+
471
+ # initialize state
472
+ state = GreedyState(
473
+ cur_len=cur_len,
474
+ sequences=sequences,
475
+ running_token=input_ids,
476
+ is_sent_finished=is_sent_finished,
477
+ model_kwargs=model_kwargs,
478
+ )
479
+
480
+ def greedy_search_cond_fn(state):
481
+ """state termination condition fn."""
482
+ has_reached_max_length = state.cur_len == max_length
483
+ all_sequence_finished = jnp.all(state.is_sent_finished)
484
+ finish_generation = jnp.logical_or(
485
+ has_reached_max_length, all_sequence_finished
486
+ )
487
+ return ~finish_generation
488
+
489
+ def greedy_search_body_fn(state):
490
+ """state update fn."""
491
+ model_outputs = model(
492
+ state.running_token, params=params, **state.model_kwargs
493
+ )
494
+ logits = model_outputs.logits[:, -1]
495
+
496
+ # apply min_length, ...
497
+ logits = logits_processor(state.sequences, logits, state.cur_len)
498
+
499
+ next_token = jnp.argmax(logits, axis=-1)
500
+
501
+ next_is_sent_finished = state.is_sent_finished | (
502
+ next_token == eos_token_id
503
+ )
504
+ next_token = (
505
+ next_token * ~next_is_sent_finished
506
+ + pad_token_id * next_is_sent_finished
507
+ )
508
+ next_token = next_token[:, None]
509
+
510
+ next_sequences = lax.dynamic_update_slice(
511
+ state.sequences, next_token, (0, state.cur_len)
512
+ )
513
+ next_model_kwargs = self.update_inputs_for_generation(
514
+ model_outputs, state.model_kwargs
515
+ )
516
+ return GreedyState(
517
+ cur_len=state.cur_len + 1,
518
+ sequences=next_sequences,
519
+ running_token=next_token,
520
+ is_sent_finished=next_is_sent_finished,
521
+ model_kwargs=next_model_kwargs,
522
+ )
523
+
524
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
525
+ if input_ids.shape[1] > 1:
526
+ state = greedy_search_body_fn(state)
527
+
528
+ if not trace:
529
+ state = self._run_loop_in_debug(
530
+ greedy_search_cond_fn, greedy_search_body_fn, state
531
+ )
532
+ else:
533
+ state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
534
+
535
+ return FlaxGreedySearchOutput(sequences=state.sequences)
536
+
537
+ def _sample(
538
+ self,
539
+ input_ids: None,
540
+ max_length: Optional[int] = None,
541
+ pad_token_id: Optional[int] = None,
542
+ eos_token_id: Optional[int] = None,
543
+ prng_key: Optional[jax_xla.DeviceArray] = None,
544
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
545
+ logits_warper: Optional[FlaxLogitsProcessorList] = None,
546
+ trace: bool = True,
547
+ params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
548
+ model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
549
+ ):
550
+ # init values
551
+ max_length = (
552
+ max_length
553
+ if max_length is not None
554
+ else self.config.marian_config.max_length
555
+ )
556
+ pad_token_id = (
557
+ pad_token_id
558
+ if pad_token_id is not None
559
+ else self.config.marian_config.pad_token_id
560
+ )
561
+ eos_token_id = (
562
+ eos_token_id
563
+ if eos_token_id is not None
564
+ else self.config.marian_config.eos_token_id
565
+ )
566
+ prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
567
+
568
+ batch_size, cur_len = input_ids.shape
569
+
570
+ eos_token_id = jnp.array(eos_token_id)
571
+ pad_token_id = jnp.array(pad_token_id)
572
+ cur_len = jnp.array(cur_len)
573
+
574
+ # per batch-item holding current token in loop.
575
+ sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
576
+ sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
577
+
578
+ # per batch-item state bit indicating if sentence has finished.
579
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
580
+
581
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
582
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
583
+ model = self.decode if self.config.is_encoder_decoder else self
584
+
585
+ # initialize model specific kwargs
586
+ model_kwargs = self.prepare_inputs_for_generation(
587
+ input_ids, max_length, **model_kwargs
588
+ )
589
+
590
+ # initialize state
591
+ state = SampleState(
592
+ cur_len=cur_len,
593
+ sequences=sequences,
594
+ running_token=input_ids,
595
+ is_sent_finished=is_sent_finished,
596
+ prng_key=prng_key,
597
+ model_kwargs=model_kwargs,
598
+ )
599
+
600
+ def sample_search_cond_fn(state):
601
+ """state termination condition fn."""
602
+ has_reached_max_length = state.cur_len == max_length
603
+ all_sequence_finished = jnp.all(state.is_sent_finished)
604
+ finish_generation = jnp.logical_or(
605
+ has_reached_max_length, all_sequence_finished
606
+ )
607
+ return ~finish_generation
608
+
609
+ def sample_search_body_fn(state):
610
+ """state update fn."""
611
+ prng_key, prng_key_next = jax.random.split(state.prng_key)
612
+ model_outputs = model(
613
+ state.running_token, params=params, **state.model_kwargs
614
+ )
615
+
616
+ logits = model_outputs.logits[:, -1]
617
+
618
+ # apply min_length, ...
619
+ logits = logits_processor(state.sequences, logits, state.cur_len)
620
+ # apply top_k, top_k, temperature
621
+ logits = logits_warper(logits, logits, state.cur_len)
622
+
623
+ next_token = jax.random.categorical(
624
+ prng_key, model_outputs.logits[:, -1], axis=-1
625
+ )
626
+
627
+ next_is_sent_finished = state.is_sent_finished | (
628
+ next_token == eos_token_id
629
+ )
630
+ next_token = (
631
+ next_token * ~next_is_sent_finished
632
+ + pad_token_id * next_is_sent_finished
633
+ )
634
+ next_token = next_token[:, None]
635
+
636
+ next_sequences = lax.dynamic_update_slice(
637
+ state.sequences, next_token, (0, state.cur_len)
638
+ )
639
+ next_model_kwargs = self.update_inputs_for_generation(
640
+ model_outputs, state.model_kwargs
641
+ )
642
+
643
+ return SampleState(
644
+ cur_len=state.cur_len + 1,
645
+ sequences=next_sequences,
646
+ running_token=next_token,
647
+ is_sent_finished=next_is_sent_finished,
648
+ model_kwargs=next_model_kwargs,
649
+ prng_key=prng_key_next,
650
+ )
651
+
652
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
653
+ if input_ids.shape[1] > 1:
654
+ state = sample_search_body_fn(state)
655
+
656
+ if not trace:
657
+ state = self._run_loop_in_debug(
658
+ sample_search_cond_fn, sample_search_body_fn, state
659
+ )
660
+ else:
661
+ state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
662
+
663
+ return FlaxSampleOutput(sequences=state.sequences)
664
+
665
+ def _beam_search(
666
+ self,
667
+ input_ids: None,
668
+ max_length: Optional[int] = None,
669
+ pad_token_id: Optional[int] = None,
670
+ eos_token_id: Optional[int] = None,
671
+ length_penalty: Optional[float] = None,
672
+ early_stopping: Optional[bool] = None,
673
+ logits_processor: Optional[FlaxLogitsProcessorList] = None,
674
+ trace: bool = True,
675
+ params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
676
+ model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
677
+ ):
678
+ """
679
+ This beam search function is heavily inspired by Flax's official example:
680
+ https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
681
+ """
682
+
683
+ def flatten_beam_dim(tensor):
684
+ """Flattens the first two dimensions of a non-scalar array."""
685
+ # ignore scalars (e.g. cache index)
686
+ if tensor.ndim == 0:
687
+ return tensor
688
+ return tensor.reshape(
689
+ (tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:]
690
+ )
691
+
692
+ def unflatten_beam_dim(tensor, batch_size, num_beams):
693
+ """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
694
+ # ignore scalars (e.g. cache index)
695
+ if tensor.ndim == 0:
696
+ return tensor
697
+ return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])
698
+
699
+ def gather_beams(nested, beam_indices, batch_size, new_num_beams):
700
+ """
701
+ Gathers the beam slices indexed by beam_indices into new beam array.
702
+ """
703
+ batch_indices = jnp.reshape(
704
+ jnp.arange(batch_size * new_num_beams) // new_num_beams,
705
+ (batch_size, new_num_beams),
706
+ )
707
+
708
+ def gather_fn(tensor):
709
+ # ignore scalars (e.g. cache index)
710
+ if tensor.ndim == 0:
711
+ return tensor
712
+ else:
713
+ return tensor[batch_indices, beam_indices]
714
+
715
+ return jax.tree_map(gather_fn, nested)
716
+
717
+ # init values
718
+ max_length = (
719
+ max_length
720
+ if max_length is not None
721
+ else self.config.marian_config.max_length
722
+ )
723
+ pad_token_id = (
724
+ pad_token_id
725
+ if pad_token_id is not None
726
+ else self.config.marian_config.pad_token_id
727
+ )
728
+ eos_token_id = (
729
+ eos_token_id
730
+ if eos_token_id is not None
731
+ else self.config.marian_config.eos_token_id
732
+ )
733
+ length_penalty = (
734
+ length_penalty
735
+ if length_penalty is not None
736
+ else self.config.marian_config.length_penalty
737
+ )
738
+ early_stopping = (
739
+ early_stopping
740
+ if early_stopping is not None
741
+ else self.config.marian_config.early_stopping
742
+ )
743
+
744
+ batch_size, num_beams, cur_len = input_ids.shape
745
+
746
+ eos_token_id = jnp.array(eos_token_id)
747
+ pad_token_id = jnp.array(pad_token_id)
748
+ cur_len = jnp.array(cur_len)
749
+
750
+ # per batch,beam-item holding current token in loop.
751
+ sequences = jnp.full(
752
+ (batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32
753
+ )
754
+ running_sequences = jnp.full(
755
+ (batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32
756
+ )
757
+ running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))
758
+
759
+ # per batch,beam-item state bit indicating if sentence has finished.
760
+ is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)
761
+
762
+ # per batch,beam-item score, logprobs
763
+ running_scores = jnp.tile(
764
+ jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1]
765
+ )
766
+ scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)
767
+
768
+ # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
769
+ # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
770
+ model = self.decode if self.config.is_encoder_decoder else self
771
+
772
+ # flatten beam dim
773
+ if "encoder_outputs" in model_kwargs:
774
+ model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
775
+ model_kwargs["encoder_outputs"]["last_hidden_state"]
776
+ )
777
+ if "attention_mask" in model_kwargs:
778
+ model_kwargs["attention_mask"] = flatten_beam_dim(
779
+ model_kwargs["attention_mask"]
780
+ )
781
+
782
+ # initialize model specific kwargs
783
+ model_kwargs = self.prepare_inputs_for_generation(
784
+ flatten_beam_dim(input_ids), max_length, **model_kwargs
785
+ )
786
+
787
+ # initialize state
788
+ state = BeamSearchState(
789
+ cur_len=cur_len,
790
+ running_sequences=running_sequences,
791
+ running_scores=running_scores,
792
+ sequences=sequences,
793
+ scores=scores,
794
+ is_sent_finished=is_sent_finished,
795
+ model_kwargs=model_kwargs,
796
+ )
797
+
798
+ def beam_search_cond_fn(state):
799
+ """beam search state termination condition fn."""
800
+
801
+ # 1. is less than max length?
802
+ not_max_length_yet = state.cur_len < max_length
803
+
804
+ # 2. can the new beams still improve?
805
+ best_running_score = state.running_scores[:, -1:] / (
806
+ max_length ** length_penalty
807
+ )
808
+ worst_finished_score = jnp.where(
809
+ state.is_sent_finished,
810
+ jnp.min(state.scores, axis=1, keepdims=True),
811
+ np.array(-1.0e7),
812
+ )
813
+ improvement_still_possible = jnp.all(
814
+ worst_finished_score < best_running_score
815
+ )
816
+
817
+ # 3. is there still a beam that has not finished?
818
+ still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)
819
+
820
+ return not_max_length_yet & still_open_beam & improvement_still_possible
821
+
822
+ def beam_search_body_fn(state):
823
+ """beam search state update fn."""
824
+ # 1. Forward current tokens
825
+ # Collect the current position slice along length to feed the fast
826
+ # autoregressive decoder model. Flatten the beam dimension into batch
827
+ # dimension for feeding into the model.
828
+ # unflatten beam dimension
829
+ # Unflatten beam dimension in attention cache arrays
830
+ input_token = flatten_beam_dim(
831
+ lax.dynamic_slice(
832
+ state.running_sequences,
833
+ (0, 0, state.cur_len - 1),
834
+ (batch_size, num_beams, 1),
835
+ )
836
+ )
837
+ model_outputs = model(input_token, params=params, **state.model_kwargs)
838
+ logits = unflatten_beam_dim(
839
+ model_outputs.logits[:, 0], batch_size, num_beams
840
+ )
841
+ cache = jax.tree_map(
842
+ lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams),
843
+ model_outputs.past_key_values,
844
+ )
845
+
846
+ # 2. Compute log probs
847
+ # get log probabilities from logits,
848
+ # process logits with processors (*e.g.* min_length, ...), and
849
+ # add new logprobs to existing running logprobs scores.
850
+ log_probs = jax.nn.log_softmax(logits)
851
+ log_probs = logits_processor(
852
+ flatten_beam_dim(running_sequences),
853
+ flatten_beam_dim(log_probs),
854
+ state.cur_len,
855
+ )
856
+ log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
857
+ log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
858
+ vocab_size = log_probs.shape[2]
859
+ log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
860
+
861
+ # 3. Retrieve top-K
862
+ # Each item in batch has num_beams * vocab_size candidate sequences.
863
+ # For each item, get the top 2*k candidates with the highest log-
864
+ # probabilities. We gather the top 2*K beams here so that even if the best
865
+ # K sequences reach EOS simultaneously, we have another K sequences
866
+ # remaining to continue the live beam search.
867
+ # Gather the top 2*K scores from _all_ beams.
868
+ # Gather 2*k top beams.
869
+ # Recover the beam index by floor division.
870
+ # Recover token id by modulo division and expand Id array for broadcasting.
871
+ # Update sequences for the 2*K top-k new sequences.
872
+ beams_to_keep = 2 * num_beams
873
+ topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
874
+ topk_beam_indices = topk_indices // vocab_size
875
+ topk_running_sequences = gather_beams(
876
+ state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
877
+ )
878
+ topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
879
+ topk_sequences = lax.dynamic_update_slice(
880
+ topk_running_sequences, topk_ids, (0, 0, state.cur_len)
881
+ )
882
+
883
+ # 4. Check which sequences have ended
884
+ # Update current sequences:
885
+ # Did any of these sequences reach an end marker?
886
+ # To prevent these just finished sequences from being added to the current sequences
887
+ # set of active beam search sequences, set their log probs to a very large
888
+ # negative value.
889
+ did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
890
+ topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
891
+
892
+ # 5. Get running sequences scores for next
893
+ # Determine the top k beam indices (from top 2*k beams) from log probs
894
+ # and gather top k beams (from top 2*k beams).
895
+ next_topk_indices = jnp.flip(
896
+ lax.top_k(topk_log_probs, k=num_beams)[1], axis=1
897
+ )
898
+ next_running_sequences, next_running_scores = gather_beams(
899
+ [topk_sequences, topk_log_probs],
900
+ next_topk_indices,
901
+ batch_size,
902
+ num_beams,
903
+ )
904
+
905
+ # 6. Process topk logits
906
+ # Further process log probs:
907
+ # - add length penalty
908
+ # - make sure no scores can be added anymore if beam is full
909
+ # - make sure still running sequences cannot be chosen as finalized beam
910
+ topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
911
+ beams_in_batch_are_full = (
912
+ jnp.broadcast_to(
913
+ state.is_sent_finished.all(axis=-1, keepdims=True),
914
+ did_topk_just_finished.shape,
915
+ )
916
+ & early_stopping
917
+ )
918
+ add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
919
+ topk_log_probs += add_penalty * np.array(-1.0e7)
920
+
921
+ # 7. Get scores, sequences, is sentence finished for next.
922
+ # Combine sequences, scores, and flags along the beam dimension and compare
923
+ # new finished sequence scores to existing finished scores and select the
924
+ # best from the new set of beams
925
+ merged_sequences = jnp.concatenate(
926
+ [state.sequences, topk_sequences], axis=1
927
+ )
928
+ merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
929
+ merged_is_sent_finished = jnp.concatenate(
930
+ [state.is_sent_finished, did_topk_just_finished], axis=1
931
+ )
932
+ topk_merged_indices = jnp.flip(
933
+ lax.top_k(merged_scores, k=num_beams)[1], axis=1
934
+ )
935
+ next_sequences, next_scores, next_is_sent_finished = gather_beams(
936
+ [merged_sequences, merged_scores, merged_is_sent_finished],
937
+ topk_merged_indices,
938
+ batch_size,
939
+ num_beams,
940
+ )
941
+
942
+ # 8. Update model kwargs.
943
+ # Determine the top k beam indices from the original set of all beams.
944
+ # With these, gather the top k beam-associated caches.
945
+ next_running_indices = gather_beams(
946
+ topk_beam_indices, next_topk_indices, batch_size, num_beams
947
+ )
948
+ next_cache = gather_beams(
949
+ cache, next_running_indices, batch_size, num_beams
950
+ )
951
+ model_outputs["past_key_values"] = jax.tree_map(
952
+ lambda x: flatten_beam_dim(x), next_cache
953
+ )
954
+ next_model_kwargs = self.update_inputs_for_generation(
955
+ model_outputs, state.model_kwargs
956
+ )
957
+
958
+ return BeamSearchState(
959
+ cur_len=state.cur_len + 1,
960
+ running_scores=next_running_scores,
961
+ running_sequences=next_running_sequences,
962
+ scores=next_scores,
963
+ sequences=next_sequences,
964
+ is_sent_finished=next_is_sent_finished,
965
+ model_kwargs=next_model_kwargs,
966
+ )
967
+
968
+ # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
969
+ state = beam_search_body_fn(state)
970
+
971
+ if not trace:
972
+ state = self._run_loop_in_debug(
973
+ beam_search_cond_fn, beam_search_body_fn, state
974
+ )
975
+ else:
976
+ state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)
977
+
978
+ # Account for the edge-case where there are no finished sequences for a
979
+ # particular batch item. If so, return running sequences for that batch item.
980
+ none_finished = jnp.any(state.is_sent_finished, axis=1)
981
+ sequences = jnp.where(
982
+ none_finished[:, None, None], state.sequences, state.running_sequences
983
+ )
984
+ scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)
985
+
986
+ # take best beam for each batch
987
+ sequences = sequences[:, -1]
988
+ scores = scores[:, -1]
989
+
990
+ return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
flax_clip_vision_marian/modeling_clip_vision_marian.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax.core.frozen_dict import FrozenDict, unfreeze
7
+ from jax import lax
8
+ from jax.random import PRNGKey
9
+ from transformers import (
10
+ CLIPVisionConfig,
11
+ FlaxCLIPVisionModel,
12
+ FlaxMarianModel,
13
+ MarianConfig
14
+ )
15
+ from transformers.modeling_flax_outputs import (
16
+ FlaxBaseModelOutputWithPooling,
17
+ FlaxCausalLMOutputWithCrossAttentions,
18
+ FlaxSeq2SeqLMOutput,
19
+ FlaxSeq2SeqModelOutput,
20
+ )
21
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
22
+ from transformers.models.marian.modeling_flax_marian import (
23
+ FlaxMarianDecoder,
24
+ FlaxPreTrainedModel,
25
+ shift_tokens_right
26
+ )
27
+
28
+ from .modeling_clip_vision_utils import FlaxCLIPVisionMarianPreTrainedModel
29
+ from .configuration_clip_vision_marian import CLIPVisionMarianConfig
30
+
31
+
32
+ class FlaxCLIPVisionMarianModule(nn.Module):
33
+ config: CLIPVisionMarianConfig
34
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
35
+
36
+ def setup(self):
37
+ self.shared = nn.Embed(
38
+ self.config.marian_config.vocab_size,
39
+ self.config.marian_config.d_model,
40
+ embedding_init=jax.nn.initializers.normal(
41
+ self.config.marian_config.init_std, self.dtype
42
+ ),
43
+ dtype=self.dtype,
44
+ )
45
+
46
+ self.encoder = FlaxCLIPVisionModule(
47
+ self.config.clip_vision_config, dtype=self.dtype
48
+ )
49
+ self.decoder = FlaxMarianDecoder(
50
+ self.config.marian_config, dtype=self.dtype, embed_tokens=self.shared
51
+ )
52
+
53
+ self.visual_projection = nn.Dense(
54
+ self.config.marian_config.hidden_size,
55
+ dtype=self.dtype,
56
+ kernel_init=jax.nn.initializers.normal(
57
+ self.config.marian_config.init_std, self.dtype
58
+ ),
59
+ )
60
+
61
+ def _get_encoder_module(self):
62
+ return self.encoder
63
+
64
+ def _get_decoder_module(self):
65
+ return self.decoder
66
+
67
+ def __call__(
68
+ self,
69
+ pixel_values,
70
+ decoder_input_ids,
71
+ decoder_attention_mask,
72
+ decoder_position_ids,
73
+ output_attentions: bool = False,
74
+ output_hidden_states: bool = False,
75
+ return_dict: bool = True,
76
+ deterministic: bool = True,
77
+ ):
78
+
79
+ encoder_outputs = self.encoder(
80
+ pixel_values=pixel_values,
81
+ output_attentions=output_attentions,
82
+ output_hidden_states=output_hidden_states,
83
+ return_dict=return_dict,
84
+ deterministic=deterministic,
85
+ )
86
+
87
+ batch_size, sequence_length = encoder_outputs[0].shape[:2]
88
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
89
+
90
+ encoder_hidden_states = self.visual_projection(encoder_outputs[0])
91
+
92
+ decoder_outputs = self.decoder(
93
+ input_ids=decoder_input_ids,
94
+ attention_mask=decoder_attention_mask,
95
+ position_ids=decoder_position_ids,
96
+ encoder_hidden_states=encoder_hidden_states,
97
+ encoder_attention_mask=encoder_attention_mask,
98
+ output_attentions=output_attentions,
99
+ output_hidden_states=output_hidden_states,
100
+ return_dict=return_dict,
101
+ deterministic=deterministic,
102
+ )
103
+
104
+ if not return_dict:
105
+ return decoder_outputs + encoder_outputs
106
+
107
+ return FlaxSeq2SeqModelOutput(
108
+ last_hidden_state=decoder_outputs.last_hidden_state,
109
+ decoder_hidden_states=decoder_outputs.hidden_states,
110
+ decoder_attentions=decoder_outputs.attentions,
111
+ cross_attentions=decoder_outputs.cross_attentions,
112
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
113
+ encoder_hidden_states=encoder_outputs.hidden_states,
114
+ encoder_attentions=encoder_outputs.attentions,
115
+ )
116
+
117
+
118
+ class FlaxCLIPVisionMarianForConditionalGenerationModule(nn.Module):
119
+ config: CLIPVisionMarianConfig
120
+ dtype: jnp.dtype = jnp.float32
121
+ bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
122
+
123
+ def setup(self):
124
+ self.model = FlaxCLIPVisionMarianModule(config=self.config, dtype=self.dtype)
125
+ self.lm_head = nn.Dense(
126
+ self.model.shared.num_embeddings,
127
+ use_bias=False,
128
+ dtype=self.dtype,
129
+ kernel_init=jax.nn.initializers.normal(
130
+ self.config.marian_config.init_std, self.dtype
131
+ ),
132
+ )
133
+ self.final_logits_bias = self.param(
134
+ "final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)
135
+ )
136
+
137
+ def _get_encoder_module(self):
138
+ return self.model.encoder
139
+
140
+ def _get_decoder_module(self):
141
+ return self.model.decoder
142
+
143
+ def _get_visual_projection_module(self):
144
+ return self.model.visual_projection
145
+
146
+ def __call__(
147
+ self,
148
+ pixel_values,
149
+ decoder_input_ids,
150
+ decoder_attention_mask,
151
+ decoder_position_ids,
152
+ output_attentions: bool = False,
153
+ output_hidden_states: bool = False,
154
+ return_dict: bool = True,
155
+ deterministic: bool = True,
156
+ ):
157
+ outputs = self.model(
158
+ pixel_values=pixel_values,
159
+ decoder_input_ids=decoder_input_ids,
160
+ decoder_attention_mask=decoder_attention_mask,
161
+ decoder_position_ids=decoder_position_ids,
162
+ output_attentions=output_attentions,
163
+ output_hidden_states=output_hidden_states,
164
+ return_dict=return_dict,
165
+ deterministic=deterministic,
166
+ )
167
+
168
+ hidden_states = outputs[0]
169
+
170
+ if self.config.tie_word_embeddings:
171
+ shared_embedding = self.model.variables["params"]["shared"]["embedding"]
172
+ lm_logits = self.lm_head.apply(
173
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
174
+ )
175
+ else:
176
+ lm_logits = self.lm_head(hidden_states)
177
+
178
+ lm_logits += self.final_logits_bias
179
+
180
+ if not return_dict:
181
+ output = (lm_logits,) + outputs[1:]
182
+ return output
183
+
184
+ return FlaxSeq2SeqLMOutput(
185
+ logits=lm_logits,
186
+ decoder_hidden_states=outputs.decoder_hidden_states,
187
+ decoder_attentions=outputs.decoder_attentions,
188
+ cross_attentions=outputs.cross_attentions,
189
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
190
+ encoder_hidden_states=outputs.encoder_hidden_states,
191
+ encoder_attentions=outputs.encoder_attentions,
192
+ )
193
+
194
+
195
+ class FlaxCLIPVisionMarianOuterPreTrainedModel(FlaxCLIPVisionMarianPreTrainedModel):
196
+ config_class = CLIPVisionMarianConfig
197
+ base_model_prefix: str = "model"
198
+ module_class: nn.Module = None
199
+
200
+ def __init__(
201
+ self,
202
+ config: CLIPVisionMarianConfig,
203
+ input_shape: Tuple = None,
204
+ seed: int = 0,
205
+ dtype: jnp.dtype = jnp.float32,
206
+ **kwargs,
207
+ ):
208
+ if input_shape is None:
209
+ input_shape = (
210
+ (
211
+ 1,
212
+ config.clip_vision_config.image_size,
213
+ config.clip_vision_config.image_size,
214
+ 3,
215
+ ),
216
+ (1, 1),
217
+ )
218
+
219
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
220
+ super().__init__(
221
+ config, module, input_shape=input_shape, seed=seed, dtype=dtype
222
+ )
223
+
224
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
225
+ # init input tensors
226
+ pixel_values = jax.random.normal(rng, input_shape[0])
227
+ # # make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
228
+ # input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
229
+
230
+ decoder_input_ids = jnp.zeros(input_shape[1], dtype="i4")
231
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
232
+
233
+ batch_size, sequence_length = decoder_input_ids.shape
234
+ decoder_position_ids = jnp.broadcast_to(
235
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
236
+ )
237
+
238
+ params_rng, dropout_rng = jax.random.split(rng)
239
+ rngs = {"params": params_rng, "dropout": dropout_rng}
240
+
241
+ return self.module.init(
242
+ rngs,
243
+ pixel_values,
244
+ decoder_input_ids,
245
+ decoder_attention_mask,
246
+ decoder_position_ids,
247
+ )["params"]
248
+
249
+ def init_cache(self, batch_size, max_length, encoder_outputs):
250
+
251
+ decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
252
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
253
+ decoder_position_ids = jnp.broadcast_to(
254
+ jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
255
+ decoder_input_ids.shape,
256
+ )
257
+
258
+ def _decoder_forward(
259
+ module,
260
+ decoder_input_ids,
261
+ decoder_attention_mask,
262
+ decoder_position_ids,
263
+ **kwargs,
264
+ ):
265
+ decoder_module = module._get_decoder_module()
266
+ return decoder_module(
267
+ decoder_input_ids,
268
+ decoder_attention_mask,
269
+ decoder_position_ids,
270
+ **kwargs,
271
+ )
272
+
273
+ init_variables = self.module.init(
274
+ jax.random.PRNGKey(0),
275
+ decoder_input_ids=decoder_input_ids,
276
+ decoder_attention_mask=decoder_attention_mask,
277
+ decoder_position_ids=decoder_position_ids,
278
+ encoder_hidden_states=encoder_outputs[0],
279
+ init_cache=True,
280
+ method=_decoder_forward, # we only need to call the decoder to init the cache
281
+ )
282
+ return unfreeze(init_variables["cache"])
283
+
284
+ def encode(
285
+ self,
286
+ pixel_values: jnp.ndarray,
287
+ output_attentions: Optional[bool] = None,
288
+ output_hidden_states: Optional[bool] = None,
289
+ return_dict: Optional[bool] = None,
290
+ train: bool = False,
291
+ params: dict = None,
292
+ dropout_rng: PRNGKey = None,
293
+ ):
294
+ output_attentions = (
295
+ output_attentions
296
+ if output_attentions is not None
297
+ else self.config.output_attentions
298
+ )
299
+ output_hidden_states = (
300
+ output_hidden_states
301
+ if output_hidden_states is not None
302
+ else self.config.output_hidden_states
303
+ )
304
+ return_dict = (
305
+ return_dict if return_dict is not None else self.config.return_dict
306
+ )
307
+
308
+ #pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
309
+
310
+ # Handle any PRNG if needed
311
+ rngs = {}
312
+ if dropout_rng is not None:
313
+ rngs["dropout"] = dropout_rng
314
+
315
+ def _encoder_forward(module, pixel_values, **kwargs):
316
+ encode_module = module._get_encoder_module()
317
+ visual_projection = module._get_visual_projection_module()
318
+
319
+ outputs = encode_module(pixel_values, **kwargs)
320
+
321
+ return FlaxBaseModelOutputWithPooling(
322
+ last_hidden_state=visual_projection(outputs.last_hidden_state),
323
+ pooler_output=outputs.pooler_output,
324
+ hidden_states=outputs.hidden_states,
325
+ attentions=outputs.attentions,
326
+ )
327
+
328
+ return self.module.apply(
329
+ {"params": params or self.params},
330
+ pixel_values=jnp.array(pixel_values, dtype="i4"),
331
+ output_attentions=output_attentions,
332
+ output_hidden_states=output_hidden_states,
333
+ return_dict=return_dict,
334
+ deterministic=not train,
335
+ rngs=rngs,
336
+ method=_encoder_forward,
337
+ )
338
+
339
+ def decode(
340
+ self,
341
+ decoder_input_ids,
342
+ encoder_outputs,
343
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
344
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
345
+ decoder_position_ids: Optional[jnp.ndarray] = None,
346
+ past_key_values: dict = None,
347
+ output_attentions: Optional[bool] = None,
348
+ output_hidden_states: Optional[bool] = None,
349
+ return_dict: Optional[bool] = None,
350
+ train: bool = False,
351
+ params: dict = None,
352
+ dropout_rng: PRNGKey = None,
353
+ ):
354
+
355
+ output_attentions = (
356
+ output_attentions
357
+ if output_attentions is not None
358
+ else self.config.output_attentions
359
+ )
360
+ output_hidden_states = (
361
+ output_hidden_states
362
+ if output_hidden_states is not None
363
+ else self.config.output_hidden_states
364
+ )
365
+ return_dict = (
366
+ return_dict if return_dict is not None else self.config.return_dict
367
+ )
368
+
369
+ encoder_hidden_states = encoder_outputs[0]
370
+
371
+ if encoder_attention_mask is None:
372
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
373
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
374
+
375
+ batch_size, sequence_length = decoder_input_ids.shape
376
+ if decoder_attention_mask is None:
377
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
378
+
379
+ if decoder_position_ids is None:
380
+ if past_key_values is not None:
381
+ raise ValueError(
382
+ "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
383
+ )
384
+
385
+ decoder_position_ids = jnp.broadcast_to(
386
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
387
+ )
388
+
389
+ # Handle any PRNG if needed
390
+ rngs = {}
391
+ if dropout_rng is not None:
392
+ rngs["dropout"] = dropout_rng
393
+
394
+ inputs = {"params": params or self.params}
395
+
396
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
397
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
398
+ # it can be changed by FlaxMarianAttention module
399
+ if past_key_values:
400
+ inputs["cache"] = past_key_values
401
+ mutable = ["cache"]
402
+ else:
403
+ mutable = False
404
+
405
+ def _decoder_forward(
406
+ module,
407
+ decoder_input_ids,
408
+ decoder_attention_mask,
409
+ decoder_position_ids,
410
+ **kwargs,
411
+ ):
412
+ decoder_module = module._get_decoder_module()
413
+ return decoder_module(
414
+ decoder_input_ids,
415
+ decoder_attention_mask,
416
+ decoder_position_ids,
417
+ **kwargs,
418
+ )
419
+
420
+ outputs = self.module.apply(
421
+ inputs,
422
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
423
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
424
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
425
+ encoder_hidden_states=encoder_hidden_states,
426
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
427
+ output_attentions=output_attentions,
428
+ output_hidden_states=output_hidden_states,
429
+ return_dict=return_dict,
430
+ deterministic=not train,
431
+ rngs=rngs,
432
+ mutable=mutable,
433
+ method=_decoder_forward,
434
+ )
435
+
436
+ # add updated cache to model output
437
+ if past_key_values is not None and return_dict:
438
+ outputs, past = outputs
439
+ outputs["past_key_values"] = unfreeze(past["cache"])
440
+ return outputs
441
+ elif past_key_values is not None and not return_dict:
442
+ outputs, past = outputs
443
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
444
+
445
+ return outputs
446
+
447
+ def __call__(
448
+ self,
449
+ pixel_values: jnp.ndarray,
450
+ decoder_input_ids: Optional[jnp.ndarray] = None,
451
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
452
+ decoder_position_ids: Optional[jnp.ndarray] = None,
453
+ output_attentions: Optional[bool] = None,
454
+ output_hidden_states: Optional[bool] = None,
455
+ return_dict: Optional[bool] = None,
456
+ train: bool = False,
457
+ params: dict = None,
458
+ dropout_rng: PRNGKey = None,
459
+ ):
460
+ output_attentions = (
461
+ output_attentions
462
+ if output_attentions is not None
463
+ else self.config.output_attentions
464
+ )
465
+ output_hidden_states = (
466
+ output_hidden_states
467
+ if output_hidden_states is not None
468
+ else self.config.output_hidden_states
469
+ )
470
+ return_dict = (
471
+ return_dict if return_dict is not None else self.config.return_dict
472
+ )
473
+
474
+ #pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
475
+
476
+ # # prepare encoder inputs
477
+ # if attention_mask is None:
478
+ # attention_mask = jnp.ones_like(input_ids)
479
+ # if position_ids is None:
480
+ # batch_size, sequence_length = input_ids.shape
481
+ # position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
482
+
483
+ # prepare decoder inputs
484
+ # if decoder_input_ids is None:
485
+ # decoder_input_ids = shift_tokens_right(
486
+ # input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
487
+ # ) # TODO: Check how to use this
488
+ if decoder_attention_mask is None:
489
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
490
+ if decoder_position_ids is None:
491
+ batch_size, sequence_length = decoder_input_ids.shape
492
+ decoder_position_ids = jnp.broadcast_to(
493
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
494
+ )
495
+
496
+ # Handle any PRNG if needed
497
+ rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
498
+
499
+ return self.module.apply(
500
+ {"params": params or self.params},
501
+ pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
502
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
503
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
504
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
505
+ output_attentions=output_attentions,
506
+ output_hidden_states=output_hidden_states,
507
+ return_dict=return_dict,
508
+ deterministic=not train,
509
+ rngs=rngs,
510
+ )
511
+
512
+
513
+ class FlaxCLIPVisionMarianForConditionalGeneration(
514
+ FlaxCLIPVisionMarianOuterPreTrainedModel
515
+ ):
516
+ module_class = FlaxCLIPVisionMarianForConditionalGenerationModule
517
+ dtype: jnp.dtype = jnp.float32
518
+
519
+ def decode(
520
+ self,
521
+ decoder_input_ids,
522
+ encoder_outputs,
523
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
524
+ decoder_attention_mask: Optional[jnp.ndarray] = None,
525
+ decoder_position_ids: Optional[jnp.ndarray] = None,
526
+ past_key_values: dict = None,
527
+ output_attentions: Optional[bool] = None,
528
+ output_hidden_states: Optional[bool] = None,
529
+ return_dict: Optional[bool] = None,
530
+ deterministic: bool = True,
531
+ params: dict = None,
532
+ dropout_rng: PRNGKey = None,
533
+ ):
534
+ output_attentions = (
535
+ output_attentions
536
+ if output_attentions is not None
537
+ else self.config.output_attentions
538
+ )
539
+ output_hidden_states = (
540
+ output_hidden_states
541
+ if output_hidden_states is not None
542
+ else self.config.output_hidden_states
543
+ )
544
+ return_dict = (
545
+ return_dict if return_dict is not None else self.config.return_dict
546
+ )
547
+
548
+ encoder_hidden_states = encoder_outputs[0]
549
+
550
+ if encoder_attention_mask is None:
551
+ batch_size, sequence_length = encoder_hidden_states.shape[:2]
552
+ encoder_attention_mask = jnp.ones((batch_size, sequence_length))
553
+
554
+ batch_size, sequence_length = decoder_input_ids.shape
555
+ if decoder_attention_mask is None:
556
+ decoder_attention_mask = jnp.ones((batch_size, sequence_length))
557
+
558
+ if decoder_position_ids is None:
559
+ if past_key_values is not None:
560
+ raise ValueError(
561
+ "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
562
+ )
563
+
564
+ decoder_position_ids = jnp.broadcast_to(
565
+ jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
566
+ )
567
+
568
+ # Handle any PRNG if needed
569
+ rngs = {}
570
+ if dropout_rng is not None:
571
+ rngs["dropout"] = dropout_rng
572
+
573
+ inputs = {"params": params or self.params}
574
+
575
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
576
+ # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
577
+ # it can be changed by FlaxMarianAttention module
578
+ if past_key_values:
579
+ inputs["cache"] = past_key_values
580
+ mutable = ["cache"]
581
+ else:
582
+ mutable = False
583
+
584
+ def _decoder_forward(
585
+ module,
586
+ decoder_input_ids,
587
+ decoder_attention_mask,
588
+ decoder_position_ids,
589
+ **kwargs,
590
+ ):
591
+ decoder_module = module._get_decoder_module()
592
+ outputs = decoder_module(
593
+ decoder_input_ids,
594
+ decoder_attention_mask,
595
+ decoder_position_ids,
596
+ **kwargs,
597
+ )
598
+ hidden_states = outputs[0]
599
+
600
+ if self.config.tie_word_embeddings:
601
+ shared_embedding = module.model.variables["params"]["shared"][
602
+ "embedding"
603
+ ]
604
+ lm_logits = module.lm_head.apply(
605
+ {"params": {"kernel": shared_embedding.T}}, hidden_states
606
+ )
607
+ else:
608
+ lm_logits = module.lm_head(hidden_states)
609
+
610
+ lm_logits += module.final_logits_bias
611
+ return lm_logits, outputs
612
+
613
+ outputs = self.module.apply(
614
+ inputs,
615
+ decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
616
+ decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
617
+ decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
618
+ encoder_hidden_states=encoder_hidden_states,
619
+ encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
620
+ output_attentions=output_attentions,
621
+ output_hidden_states=output_hidden_states,
622
+ return_dict=return_dict,
623
+ deterministic=deterministic,
624
+ rngs=rngs,
625
+ mutable=mutable,
626
+ method=_decoder_forward,
627
+ )
628
+
629
+ if past_key_values is None:
630
+ lm_logits, decoder_outputs = outputs
631
+ else:
632
+ (lm_logits, decoder_outputs), past = outputs
633
+
634
+ if return_dict:
635
+ outputs = FlaxCausalLMOutputWithCrossAttentions(
636
+ logits=lm_logits,
637
+ hidden_states=decoder_outputs.hidden_states,
638
+ attentions=decoder_outputs.attentions,
639
+ cross_attentions=decoder_outputs.cross_attentions,
640
+ )
641
+ else:
642
+ outputs = (lm_logits,) + decoder_outputs[1:]
643
+
644
+ # add updated cache to model output
645
+ if past_key_values is not None and return_dict:
646
+ outputs["past_key_values"] = unfreeze(past["cache"])
647
+ return outputs
648
+ elif past_key_values is not None and not return_dict:
649
+ outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
650
+
651
+ return outputs
652
+
653
+ def prepare_inputs_for_generation(
654
+ self,
655
+ decoder_input_ids,
656
+ max_length,
657
+ attention_mask: Optional[jnp.DeviceArray] = None,
658
+ decoder_attention_mask: Optional[jnp.DeviceArray] = None,
659
+ encoder_outputs=None,
660
+ **kwargs,
661
+ ):
662
+ # initializing the cache
663
+ batch_size, seq_length = decoder_input_ids.shape
664
+
665
+ past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
666
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
667
+ # But since the decoder uses a causal mask, those positions are masked anyways.
668
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
669
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
670
+ if decoder_attention_mask is not None:
671
+ position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
672
+ extended_attention_mask = lax.dynamic_update_slice(
673
+ extended_attention_mask, decoder_attention_mask, (0, 0)
674
+ )
675
+ else:
676
+ position_ids = jnp.broadcast_to(
677
+ jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
678
+ )
679
+
680
+ return {
681
+ "past_key_values": past_key_values,
682
+ "encoder_outputs": encoder_outputs,
683
+ "encoder_attention_mask": attention_mask,
684
+ "decoder_attention_mask": extended_attention_mask,
685
+ "decoder_position_ids": position_ids,
686
+ }
687
+
688
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
689
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
690
+ model_kwargs["decoder_position_ids"] = (
691
+ model_kwargs["decoder_position_ids"][:, -1:] + 1
692
+ )
693
+ return model_kwargs
694
+
695
+ @classmethod
696
+ def from_pretrained(cls, *args, **kwargs):
697
+ # At the moment fast initialization is not supported
698
+ # for composite models
699
+ # kwargs["_fast_init"] = False
700
+ return super().from_pretrained(*args, **kwargs)
701
+
702
+ @classmethod
703
+ def from_clip_vision_marian_pretrained(
704
+ cls,
705
+ clip_vision_model_name_or_path: str = None,
706
+ marian_model_name_or_path: str = None,
707
+ *model_args,
708
+ **kwargs,
709
+ ) -> FlaxCLIPVisionMarianPreTrainedModel:
710
+
711
+ kwargs_marian = {
712
+ argument[len("marian_") :]: value
713
+ for argument, value in kwargs.items()
714
+ if argument.startswith("marian_")
715
+ }
716
+ kwargs_clip_vision = {
717
+ argument[len("clip_vision_") :]: value
718
+ for argument, value in kwargs.items()
719
+ if argument.startswith("clip_vision_")
720
+ }
721
+ # remove marian, clip_vision kwargs from kwargs
722
+ for key in kwargs_marian.keys():
723
+ del kwargs["marian_" + key]
724
+ for key in kwargs_clip_vision.keys():
725
+ del kwargs["clip_vision_" + key]
726
+
727
+ # Load and initialize the marian and clip_vision model
728
+ marian_model = kwargs_marian.pop("model", None)
729
+ if marian_model is None:
730
+ assert (
731
+ marian_model_name_or_path is not None
732
+ ), "If `model` is not defined as an argument, a `marian_model_name_or_path` has to be defined"
733
+
734
+ if "config" not in kwargs_marian:
735
+ marian_config = MarianConfig.from_pretrained(marian_model_name_or_path)
736
+ kwargs_marian["config"] = marian_config
737
+
738
+ marian_model = FlaxMarianModel.from_pretrained(
739
+ marian_model_name_or_path, *model_args, **kwargs_marian,from_pt=True
740
+ )
741
+ clip_vision_model = kwargs_clip_vision.pop("model", None)
742
+ if clip_vision_model is None:
743
+ assert (
744
+ clip_vision_model_name_or_path is not None
745
+ ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
746
+
747
+ if "config" not in kwargs_clip_vision:
748
+ clip_vision_config = CLIPVisionConfig.from_pretrained(
749
+ clip_vision_model_name_or_path
750
+ )
751
+ kwargs_clip_vision["config"] = clip_vision_config
752
+
753
+ clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
754
+ clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
755
+ )
756
+
757
+ # instantiate config with corresponding kwargs
758
+ dtype = kwargs.pop("dtype", jnp.float32)
759
+ config = CLIPVisionMarianConfig.from_clip_vision_marian_configs(
760
+ clip_vision_model.config, marian_model.config, **kwargs
761
+ )
762
+
763
+ # init model
764
+ model = cls(config, *model_args, dtype=dtype, **kwargs)
765
+ model.params["model"]["encoder"] = clip_vision_model.params
766
+ model.params["model"]["decoder"] = marian_model.params["decoder"]
767
+ model.params["model"]["shared"] = marian_model.params["shared"]
768
+ # model.params["marian_model"] = marian_model.params
769
+
770
+ return model
771
+
772
+
773
+ # flax_clip_vision_marian_cg = FlaxCLIPVisionmarianForConditionalGeneration.from_clip_vision_marian_pretrained('openai/clip-vit-base-patch32', 'facebook/marian-large')
774
+ # outputs = flax_clip_vision_marian_cg(pixel_values, input_ids, attention_mask, position_ids, output_hidden_states=True)
775
+ # flax_vit_bart_cg.generate(input_ids=pixel_values, decoder_start_token_id=tokenizer.lang_code_to_id['en_XX'])s
776
+ #flax_clip_vision_marian_cg = FlaxCLIPVisionMarianForConditionalGeneration.from_clip_vision_marian_pretrained('openai/clip-vit-base-patch32','Helsinki-NLP/opus-mt-en-id')
flax_clip_vision_marian/modeling_clip_vision_utils.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NEW
2
+
3
+ import os
4
+
5
+ # from functools import partial
6
+ from pickle import UnpicklingError
7
+ from typing import Dict, Set, Tuple, Union
8
+
9
+ import flax.linen as nn
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from flax.core.frozen_dict import FrozenDict, unfreeze
13
+ from flax.serialization import from_bytes, to_bytes
14
+ from flax.traverse_util import flatten_dict, unflatten_dict
15
+ from jax.random import PRNGKey
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.file_utils import (
18
+ FLAX_WEIGHTS_NAME,
19
+ WEIGHTS_NAME,
20
+ PushToHubMixin,
21
+ cached_path,
22
+ hf_bucket_url,
23
+ is_offline_mode,
24
+ is_remote_url,
25
+ )
26
+ from transformers.modeling_flax_pytorch_utils import (
27
+ load_pytorch_checkpoint_in_flax_state_dict,
28
+ )
29
+ from transformers.utils import logging
30
+
31
+ from .generation_clip_vision_utils import FlaxCLIPVisionMarianGenerationMixin
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ class FlaxCLIPVisionMarianPreTrainedModel(
37
+ PushToHubMixin, FlaxCLIPVisionMarianGenerationMixin
38
+ ):
39
+ r"""
40
+ Base class for all models.
41
+ :class:`~transformers.FlaxPreTrainedModel` takes care of storing the configuration of the models and handles
42
+ methods for loading, downloading and saving models.
43
+ Class attributes (overridden by derived classes):
44
+ - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
45
+ :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
46
+ - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
47
+ derived classes of the same architecture adding modules on top of the base model.
48
+ """
49
+ config_class = None
50
+ base_model_prefix = ""
51
+
52
+ def __init__(
53
+ self,
54
+ config: PretrainedConfig,
55
+ module: nn.Module,
56
+ input_shape: Tuple = (1, 1),
57
+ seed: int = 0,
58
+ dtype: jnp.dtype = jnp.float32,
59
+ ):
60
+ if config is None:
61
+ raise ValueError("config cannot be None")
62
+
63
+ if module is None:
64
+ raise ValueError("module cannot be None")
65
+
66
+ # Those are private to be exposed as typed property on derived classes.
67
+ self._config = config
68
+ self._module = module
69
+
70
+ # Those are public as their type is generic to every derived classes.
71
+ self.key = PRNGKey(seed)
72
+ self.dtype = dtype
73
+
74
+ # randomly initialized parameters
75
+ random_params = self.init_weights(self.key, input_shape)
76
+
77
+ # save required_params as set
78
+ self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
79
+ self.params = random_params
80
+
81
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> Dict:
82
+ raise NotImplementedError(f"init method has to be implemented for {self}")
83
+
84
+ @classmethod
85
+ def _from_config(cls, config, **kwargs):
86
+ """
87
+ All context managers that the model should be initialized under go here.
88
+ """
89
+ return cls(config, **kwargs)
90
+
91
+ @property
92
+ def config(self) -> PretrainedConfig:
93
+ return self._config
94
+
95
+ @property
96
+ def module(self) -> nn.Module:
97
+ return self._module
98
+
99
+ @property
100
+ def params(self) -> Union[Dict, FrozenDict]:
101
+ return self._params
102
+
103
+ @property
104
+ def required_params(self) -> Set:
105
+ return self._required_params
106
+
107
+ @params.setter
108
+ def params(self, params: Union[Dict, FrozenDict]):
109
+ if isinstance(params, FrozenDict):
110
+ params = unfreeze(params)
111
+ param_keys = set(flatten_dict(params).keys())
112
+ if len(self.required_params - param_keys) > 0:
113
+ raise ValueError(
114
+ "Some parameters are missing. Make sure that `params` include the following "
115
+ f"parameters {self.required_params - param_keys}"
116
+ )
117
+ self._params = params
118
+
119
+ @classmethod
120
+ def from_pretrained(
121
+ cls,
122
+ pretrained_model_name_or_path: Union[str, os.PathLike],
123
+ dtype: jnp.dtype = jnp.float32,
124
+ *model_args,
125
+ **kwargs,
126
+ ):
127
+
128
+ r"""
129
+ Instantiate a pretrained flax model from a pre-trained model configuration.
130
+ The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
131
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
132
+ task.
133
+ The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
134
+ weights are discarded.
135
+ Parameters:
136
+ pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
137
+ Can be either:
138
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
139
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
140
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
141
+ - A path to a `directory` containing model weights saved using
142
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
143
+ - A path or url to a `pt index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In this
144
+ case, ``from_pt`` should be set to :obj:`True`.
145
+ model_args (sequence of positional arguments, `optional`):
146
+ All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
147
+ config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`):
148
+ Can be either:
149
+ - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
150
+ - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
151
+ Configuration for the model to use instead of an automatically loaded configuation. Configuration can
152
+ be automatically loaded when:
153
+ - The model is a model provided by the library (loaded with the `model id` string of a pretrained
154
+ model).
155
+ - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
156
+ by supplying the save directory.
157
+ - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
158
+ configuration JSON file named `config.json` is found in the directory.
159
+ cache_dir (:obj:`Union[str, os.PathLike]`, `optional`):
160
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
161
+ standard cache should not be used.
162
+ from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
163
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
164
+ ``pretrained_model_name_or_path`` argument).
165
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
166
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
167
+ cached versions if they exist.
168
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
169
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
170
+ file exists.
171
+ proxies (:obj:`Dict[str, str], `optional`):
172
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
173
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
174
+ local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`):
175
+ Whether or not to only look at local files (i.e., do not try to download the model).
176
+ revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
177
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
178
+ git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
179
+ identifier allowed by git.
180
+ kwargs (remaining dictionary of keyword arguments, `optional`):
181
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
182
+ :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
183
+ automatically loaded:
184
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
185
+ underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
186
+ already been done)
187
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
188
+ initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
189
+ ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
190
+ with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
191
+ attribute will be passed to the underlying model's ``__init__`` function.
192
+ Examples::
193
+ >>> from transformers import BertConfig, FlaxBertModel
194
+ >>> # Download model and configuration from huggingface.co and cache.
195
+ >>> model = FlaxBertModel.from_pretrained('bert-base-cased')
196
+ >>> # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
197
+ >>> model = FlaxBertModel.from_pretrained('./test/saved_model/')
198
+ >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
199
+ >>> config = BertConfig.from_json_file('./pt_model/config.json')
200
+ >>> model = FlaxBertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)
201
+ """
202
+ config = kwargs.pop("config", None)
203
+ cache_dir = kwargs.pop("cache_dir", None)
204
+ from_pt = kwargs.pop("from_pt", False)
205
+ force_download = kwargs.pop("force_download", False)
206
+ resume_download = kwargs.pop("resume_download", False)
207
+ proxies = kwargs.pop("proxies", None)
208
+ local_files_only = kwargs.pop("local_files_only", False)
209
+ use_auth_token = kwargs.pop("use_auth_token", None)
210
+ revision = kwargs.pop("revision", None)
211
+ from_pipeline = kwargs.pop("_from_pipeline", None)
212
+ from_auto_class = kwargs.pop("_from_auto", False)
213
+
214
+ user_agent = {
215
+ "file_type": "model",
216
+ "framework": "flax",
217
+ "from_auto_class": from_auto_class,
218
+ }
219
+ if from_pipeline is not None:
220
+ user_agent["using_pipeline"] = from_pipeline
221
+
222
+ if is_offline_mode() and not local_files_only:
223
+ logger.info("Offline mode: forcing local_files_only=True")
224
+ local_files_only = True
225
+
226
+ # Load config if we don't provide a configuration
227
+ if not isinstance(config, PretrainedConfig):
228
+ config_path = (
229
+ config if config is not None else pretrained_model_name_or_path
230
+ )
231
+ config, model_kwargs = cls.config_class.from_pretrained(
232
+ config_path,
233
+ *model_args,
234
+ cache_dir=cache_dir,
235
+ return_unused_kwargs=True,
236
+ force_download=force_download,
237
+ resume_download=resume_download,
238
+ proxies=proxies,
239
+ local_files_only=local_files_only,
240
+ use_auth_token=use_auth_token,
241
+ revision=revision,
242
+ _from_auto=from_auto_class,
243
+ _from_pipeline=from_pipeline,
244
+ **kwargs,
245
+ )
246
+ else:
247
+ model_kwargs = kwargs
248
+
249
+ # Add the dtype to model_kwargs
250
+ model_kwargs["dtype"] = dtype
251
+
252
+ # Load model
253
+ if pretrained_model_name_or_path is not None:
254
+ if os.path.isdir(pretrained_model_name_or_path):
255
+ if from_pt and os.path.isfile(
256
+ os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
257
+ ):
258
+ # Load from a PyTorch checkpoint
259
+ archive_file = os.path.join(
260
+ pretrained_model_name_or_path, WEIGHTS_NAME
261
+ )
262
+ elif os.path.isfile(
263
+ os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
264
+ ):
265
+ # Load from a Flax checkpoint
266
+ archive_file = os.path.join(
267
+ pretrained_model_name_or_path, FLAX_WEIGHTS_NAME
268
+ )
269
+ else:
270
+ raise EnvironmentError(
271
+ f"Error no file named {[FLAX_WEIGHTS_NAME, WEIGHTS_NAME]} found in directory "
272
+ f"{pretrained_model_name_or_path} or `from_pt` set to False"
273
+ )
274
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(
275
+ pretrained_model_name_or_path
276
+ ):
277
+ archive_file = pretrained_model_name_or_path
278
+ else:
279
+ archive_file = hf_bucket_url(
280
+ pretrained_model_name_or_path,
281
+ filename=WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME,
282
+ revision=revision,
283
+ )
284
+
285
+ # redirect to the cache, if necessary
286
+ try:
287
+ resolved_archive_file = cached_path(
288
+ archive_file,
289
+ cache_dir=cache_dir,
290
+ force_download=force_download,
291
+ proxies=proxies,
292
+ resume_download=resume_download,
293
+ local_files_only=local_files_only,
294
+ use_auth_token=use_auth_token,
295
+ user_agent=user_agent,
296
+ )
297
+ except EnvironmentError as err:
298
+ logger.error(err)
299
+ msg = (
300
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
301
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
302
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named {WEIGHTS_NAME}.\n\n"
303
+ )
304
+ raise EnvironmentError(msg)
305
+
306
+ if resolved_archive_file == archive_file:
307
+ logger.info(f"loading weights file {archive_file}")
308
+ else:
309
+ logger.info(
310
+ f"loading weights file {archive_file} from cache at {resolved_archive_file}"
311
+ )
312
+ else:
313
+ resolved_archive_file = None
314
+
315
+ # init random models
316
+ model = cls(config, *model_args, **model_kwargs)
317
+
318
+ if from_pt:
319
+ state = load_pytorch_checkpoint_in_flax_state_dict(
320
+ model, resolved_archive_file
321
+ )
322
+ else:
323
+ with open(resolved_archive_file, "rb") as state_f:
324
+ try:
325
+ state = from_bytes(cls, state_f.read())
326
+ except UnpicklingError:
327
+ raise EnvironmentError(
328
+ f"Unable to convert {archive_file} to Flax deserializable object. "
329
+ )
330
+ # make sure all arrays are stored as jnp.arrays
331
+ # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
332
+ # https://github.com/google/flax/issues/1261
333
+ state = jax.tree_util.tree_map(jnp.array, state)
334
+
335
+ # if model is base model only use model_prefix key
336
+ if (
337
+ cls.base_model_prefix not in dict(model.params)
338
+ and cls.base_model_prefix in state
339
+ ):
340
+ state = state[cls.base_model_prefix]
341
+
342
+ # if model is head model and we are loading weights from base model
343
+ # we initialize new params dict with base_model_prefix
344
+ if (
345
+ cls.base_model_prefix in dict(model.params)
346
+ and cls.base_model_prefix not in state
347
+ ):
348
+ state = {cls.base_model_prefix: state}
349
+
350
+ # flatten dicts
351
+ state = flatten_dict(state)
352
+
353
+ random_state = flatten_dict(unfreeze(model.params))
354
+
355
+ missing_keys = model.required_params - set(state.keys())
356
+ unexpected_keys = set(state.keys()) - model.required_params
357
+
358
+ # add missing keys as random parameters
359
+ for missing_key in missing_keys:
360
+ state[missing_key] = random_state[missing_key]
361
+
362
+ # remove unexpected keys to not be saved again
363
+ for unexpected_key in unexpected_keys:
364
+ del state[unexpected_key]
365
+
366
+ if len(unexpected_keys) > 0:
367
+ logger.warning(
368
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
369
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
370
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
371
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
372
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
373
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
374
+ )
375
+ else:
376
+ logger.info(
377
+ f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
378
+ )
379
+
380
+ if len(missing_keys) > 0:
381
+ logger.warning(
382
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
383
+ f"and are newly initialized: {missing_keys}\n"
384
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
385
+ )
386
+ else:
387
+ logger.info(
388
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
389
+ f"If your task is similar to the task the model of the checkpoint was trained on, "
390
+ f"you can already use {model.__class__.__name__} for predictions without further training."
391
+ )
392
+
393
+ # set correct parameters
394
+ model.params = unflatten_dict(state)
395
+
396
+ return model
397
+
398
+ def save_pretrained(
399
+ self,
400
+ save_directory: Union[str, os.PathLike],
401
+ params=None,
402
+ push_to_hub=False,
403
+ **kwargs,
404
+ ):
405
+ """
406
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
407
+ `:func:`~transformers.FlaxPreTrainedModel.from_pretrained`` class method
408
+ Arguments:
409
+ save_directory (:obj:`str` or :obj:`os.PathLike`):
410
+ Directory to which to save. Will be created if it doesn't exist.
411
+ push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
412
+ Whether or not to push your model to the Hugging Face model hub after saving it.
413
+ .. warning::
414
+ Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
415
+ :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
416
+ pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
417
+ instead.
418
+ kwargs:
419
+ Additional key word arguments passed along to the
420
+ :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
421
+ """
422
+ if os.path.isfile(save_directory):
423
+ logger.error(
424
+ f"Provided path ({save_directory}) should be a directory, not a file"
425
+ )
426
+ return
427
+
428
+ if push_to_hub:
429
+ commit_message = kwargs.pop("commit_message", None)
430
+ repo = self._create_or_get_repo(save_directory, **kwargs)
431
+
432
+ os.makedirs(save_directory, exist_ok=True)
433
+
434
+ # get abs dir
435
+ save_directory = os.path.abspath(save_directory)
436
+ # save config as well
437
+ self.config.architectures = [self.__class__.__name__[4:]]
438
+ self.config.save_pretrained(save_directory)
439
+
440
+ # save model
441
+ output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME)
442
+ with open(output_model_file, "wb") as f:
443
+ params = params if params is not None else self.params
444
+ model_bytes = to_bytes(params)
445
+ f.write(model_bytes)
446
+
447
+ logger.info(f"Model weights saved in {output_model_file}")
448
+
449
+ if push_to_hub:
450
+ url = self._push_to_hub(repo, commit_message=commit_message)
451
+ logger.info(f"Model pushed to the hub in this commit: {url}")
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pillow==8.2.0
2
+ aiofiles==0.4.0
3
+ numpy
4
+ pandas
5
+ git+https://github.com/huggingface/transformers.git
6
+ datasets >= 1.1.3
7
+ jax>=0.2.8
8
+ jaxlib>=0.1.59
9
+ flax>=0.3.4
10
+ optax>=0.0.9
11
+ sentencepiece
12
+ mtranslate
13
+ streamlit