santiago committed on
Commit
36fdb4d
1 Parent(s): 09ad451

feat: add baseline code

requirements.txt ADDED
@@ -0,0 +1,8 @@
+ jax>=0.2.8
+ jaxlib>=0.1.59
+ flax>=0.3.4
+ optax>=0.0.8
+ -f https://download.pytorch.org/whl/torch_stable.html
+ torch==1.9.0+cpu
+ -f https://download.pytorch.org/whl/torch_stable.html
+ torchvision==0.10.0+cpu
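As a quick sanity check after installing these pins (an illustrative snippet, not part of the commit), the core libraries can be imported and their resolved versions printed before launching training:

import flax
import jax
import jaxlib
import optax
import torch
import torchvision

# Print the installed versions so mismatches with the pins above are easy to spot.
for name, module in [("jax", jax), ("jaxlib", jaxlib), ("flax", flax), ("optax", optax), ("torch", torch), ("torchvision", torchvision)]:
    print(f"{name}=={module.__version__}")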
src/configuration_medclip.py ADDED
@@ -0,0 +1,108 @@
+ import copy
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class HybridCLIPConfig(PretrainedConfig):
+     r"""
+     :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
+     :class:`~HybridCLIPModel`. It is used to instantiate a HybridCLIPModel according to the specified arguments,
+     defining the text model and vision model configs.
+
+     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+
+     Args:
+         text_config (:obj:`dict`):
+             Dictionary of configuration options that defines the text model config.
+         vision_config (:obj:`dict`):
+             Dictionary of configuration options that defines the vision model config.
+         projection_dim (:obj:`int`, `optional`, defaults to 512):
+             Dimensionality of the text and vision projection layers.
+         kwargs (`optional`):
+             Dictionary of keyword arguments.
+
+     Examples::
+
+         >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
+
+         >>> # Initializing a BERT and CLIP configuration
+         >>> config_text = BertConfig()
+         >>> config_vision = CLIPConfig()
+
+         >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
+
+         >>> # Initializing a BERT and CLIPVision model
+         >>> model = FlaxHybridCLIP(config=config)
+
+         >>> # Accessing the model configuration
+         >>> config_text = model.config.text_config
+         >>> config_vision = model.config.vision_config
+
+         >>> # Saving the model, including its configuration
+         >>> model.save_pretrained('my-model')
+
+         >>> # loading model and config from pretrained folder
+         >>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
+         >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
+     """
+
+     model_type = "hybrid-clip"
+     is_composition = True
+
+     def __init__(self, projection_dim=512, **kwargs):
+         super().__init__(**kwargs)
+
+         if "text_config" not in kwargs:
+             raise ValueError("`text_config` cannot be `None`.")
+
+         if "vision_config" not in kwargs:
+             raise ValueError("`vision_config` cannot be `None`.")
+
+         text_config = kwargs.pop("text_config")
+         vision_config = kwargs.pop("vision_config")
+
+         text_model_type = text_config.pop("model_type")
+         vision_model_type = vision_config.pop("model_type")
+
+         from transformers import AutoConfig
+
+         self.text_config = AutoConfig.for_model(text_model_type, **text_config)
+
+         if vision_model_type == "clip":
+             self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
+         else:
+             self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
+
+         self.projection_dim = projection_dim
+         self.initializer_factor = 1.0
+
+     @classmethod
+     def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
+         r"""
+         Instantiate a :class:`HybridCLIPConfig` (or a derived class) from a text model configuration and a
+         vision model configuration.
+
+         Returns:
+             :class:`HybridCLIPConfig`: An instance of a configuration object
+         """
+
+         return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Overrides the default
+         :meth:`~transformers.PretrainedConfig.to_dict`.
+
+         Returns:
+             :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["text_config"] = self.text_config.to_dict()
+         output["vision_config"] = self.vision_config.to_dict()
+         output["model_type"] = self.__class__.model_type
+         return output
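A minimal usage sketch of this configuration class (illustrative only; it assumes the `src` directory is on `PYTHONPATH` and uses checkpoint-free default configs):

from transformers import BertConfig, CLIPConfig

from configuration_medclip import HybridCLIPConfig

# Compose a hybrid config from a text config and a CLIP config; when the vision
# config is a full CLIPConfig, only its vision sub-config is kept (see __init__ above).
config = HybridCLIPConfig.from_text_vision_configs(BertConfig(), CLIPConfig(), projection_dim=512)

# These are the fields FlaxHybridCLIPModule.setup() reads later on.
print(config.text_config.hidden_size, config.vision_config.hidden_size, config.projection_dim)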
src/modeling_medclip.py ADDED
@@ -0,0 +1,420 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from configuration_medclip import HybridCLIPConfig
+ from flax.core.frozen_dict import FrozenDict
+ from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class FlaxHybridCLIPModule(nn.Module):
+     config: HybridCLIPConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         text_config = self.config.text_config
+         vision_config = self.config.vision_config
+
+         self.projection_dim = self.config.projection_dim
+         self.text_embed_dim = text_config.hidden_size
+         self.vision_embed_dim = vision_config.hidden_size
+
+         text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
+         vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
+
+         self.text_model = text_module(text_config, dtype=self.dtype)
+         self.vision_model = vision_module(vision_config, dtype=self.dtype)
+
+         self.visual_projection = nn.Dense(
+             self.projection_dim,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+             use_bias=False,
+         )
+         self.text_projection = nn.Dense(
+             self.projection_dim,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+             use_bias=False,
+         )
+         self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
+
+     def __call__(
+         self,
+         input_ids=None,
+         pixel_values=None,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         deterministic: bool = True,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         image_embeds = vision_outputs[1]
+         image_embeds = self.visual_projection(image_embeds)
+
+         text_embeds = text_outputs[1]
+         text_embeds = self.text_projection(text_embeds)
+
+         # normalized features
+         image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+         text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
+
+         # cosine similarity as logits
+         logit_scale = jnp.exp(self.logit_scale)
+         logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
+         logits_per_image = logits_per_text.T
+
+         if not return_dict:
+             return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+
+         return FlaxCLIPOutput(
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_outputs,
+             vision_model_output=vision_outputs,
+         )
+
+
+ class FlaxHybridCLIP(FlaxPreTrainedModel):
+     config_class = HybridCLIPConfig
+     module_class = FlaxHybridCLIPModule
+
+     def __init__(
+         self,
+         config: HybridCLIPConfig,
+         input_shape: Optional[Tuple] = None,
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         **kwargs
+     ):
+         if input_shape is None:
+             input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
+
+         module = self.module_class(config=config, dtype=dtype, **kwargs)
+         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+         # init input tensor
+         input_ids = jnp.zeros(input_shape[0], dtype="i4")
+         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
+         token_type_ids = jnp.ones_like(input_ids)
+         attention_mask = jnp.ones_like(input_ids)
+
+         pixel_values = jax.random.normal(rng, input_shape[1])
+
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+
+         return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
+
+     def __call__(
+         self,
+         input_ids,
+         pixel_values,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         params: dict = None,
+         dropout_rng: jax.random.PRNGKey = None,
+         train: bool = False,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if position_ids is None:
+             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+         if token_type_ids is None:
+             token_type_ids = jnp.zeros_like(input_ids)
+
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         return self.module.apply(
+             {"params": params or self.params},
+             jnp.array(input_ids, dtype="i4"),
+             jnp.array(pixel_values, dtype=jnp.float32),
+             jnp.array(attention_mask, dtype="i4"),
+             jnp.array(position_ids, dtype="i4"),
+             jnp.array(token_type_ids, dtype="i4"),
+             not train,
+             output_attentions,
+             output_hidden_states,
+             return_dict,
+             rngs=rngs,
+         )
+
+     def get_text_features(
+         self,
+         input_ids,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         dropout_rng: jax.random.PRNGKey = None,
+         train=False,
+     ):
+         r"""
+         Args:
+             input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+
+                 Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
+                 :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
+                 for details.
+
+                 `What are input IDs? <../glossary.html#input-ids>`__
+
+         Returns:
+             text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
+                 obtained by applying the projection layer to the pooled output of the text model.
+         """
+         if position_ids is None:
+             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+         if token_type_ids is None:
+             token_type_ids = jnp.zeros_like(input_ids)
+
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
+             text_outputs = module.text_model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 token_type_ids=token_type_ids,
+                 deterministic=deterministic,
+             )
+             pooled_output = text_outputs[1]
+             text_features = module.text_projection(pooled_output)
+             return text_features
+
+         return self.module.apply(
+             {"params": self.params},
+             jnp.array(input_ids, dtype="i4"),
+             jnp.array(attention_mask, dtype="i4"),
+             jnp.array(position_ids, dtype="i4"),
+             jnp.array(token_type_ids, dtype="i4"),
+             not train,
+             method=_get_features,
+             rngs=rngs,
+         )
+
+     def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
+         r"""
+         Args:
+             pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
+                 Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
+                 using :class:`~transformers.ImageFeatureExtractionMixin`. See
+                 :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
+
+         Returns:
+             image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
+                 obtained by applying the projection layer to the pooled output of the vision model.
+         """
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         def _get_features(module, pixel_values, deterministic):
+             vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
+             pooled_output = vision_outputs[1]  # pooled_output
+             image_features = module.visual_projection(pooled_output)
+             return image_features
+
+         return self.module.apply(
+             {"params": self.params},
+             jnp.array(pixel_values, dtype=jnp.float32),
+             not train,
+             method=_get_features,
+             rngs=rngs,
+         )
+
+     @classmethod
+     def from_text_vision_pretrained(
+         cls,
+         text_model_name_or_path: str = None,
+         vision_model_name_or_path: str = None,
+         *model_args,
+         **kwargs,
+     ) -> FlaxPreTrainedModel:
+         """
+         Params:
+             text_model_name_or_path (:obj: `str`, `optional`):
+                 Information necessary to initialize the text model. Can be either:
+
+                     - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+                       Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+                       a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                     - A path to a `directory` containing model weights saved using
+                       :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                     - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
+                       this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
+                       as the ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
+                       a Flax model using the provided conversion scripts and loading the Flax model afterwards.
+
+             vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
+                 Information necessary to initialize the vision model. Can be either:
+
+                     - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+                       Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+                       a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                     - A path to a `directory` containing model weights saved using
+                       :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                     - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
+                       this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
+                       as the ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
+                       a Flax model using the provided conversion scripts and loading the Flax model afterwards.
+
+             model_args (remaining positional arguments, `optional`):
+                 All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
+
+             kwargs (remaining dictionary of keyword arguments, `optional`):
+                 Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+                 :obj:`output_attentions=True`).
+
+                 - To update the text configuration, use the prefix `text_` for each configuration parameter.
+                 - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
+                 - To update the parent model configuration, do not use a prefix for each configuration parameter.
+
+                 Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
+
+         Example::
+
+             >>> from modeling_medclip import FlaxHybridCLIP
+             >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
+             >>> # If using CLIP's vision model, the vision projection layer will be initialized using pre-trained weights
+             >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
+             >>> # saving model after fine-tuning
+             >>> model.save_pretrained("./bert-clip")
+             >>> # load fine-tuned model
+             >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
+         """
+
+         kwargs_text = {
+             argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
+         }
+
+         kwargs_vision = {
+             argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
+         }
+
+         # remove text, vision kwargs from kwargs
+         for key in kwargs_text.keys():
+             del kwargs["text_" + key]
+         for key in kwargs_vision.keys():
+             del kwargs["vision_" + key]
+
+         # Load and initialize the text and vision model
+         text_model = kwargs_text.pop("model", None)
+         if text_model is None:
+             assert (
+                 text_model_name_or_path is not None
+             ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
+             from transformers import FlaxAutoModel
+
+             if "config" not in kwargs_text:
+                 from transformers import AutoConfig
+
+                 text_config = AutoConfig.from_pretrained(text_model_name_or_path)
+                 kwargs_text["config"] = text_config
+
+             text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
+
+         vision_model = kwargs_vision.pop("model", None)
+         if vision_model is None:
+             assert (
+                 vision_model_name_or_path is not None
+             ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
+             from transformers import FlaxAutoModel
+
+             if "config" not in kwargs_vision:
+                 from transformers import AutoConfig
+
+                 vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
+                 kwargs_vision["config"] = vision_config
+
+             vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
+
+         # instantiate config with corresponding kwargs
+         dtype = kwargs.pop("dtype", jnp.float32)
+         config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
+
+         # init model
+         model = cls(config, *model_args, dtype=dtype, **kwargs)
+
+         if vision_config.model_type == "clip":
+             model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
+             model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
+         else:
+             model.params["vision_model"] = vision_model.params
+
+         model.params["text_model"] = text_model.params
+
+         return model
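A short inference sketch for the model above (illustrative only: it assumes `src` is on `PYTHONPATH` and reuses the generic `bert-base-uncased`/`openai/clip-vit-base-patch32` checkpoints from the docstring example rather than medically pre-trained encoders):

import jax.numpy as jnp
from transformers import AutoTokenizer

from modeling_medclip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer(["frontal chest x-ray", "abdominal ultrasound"], padding="max_length", max_length=72, return_tensors="np")
text_embeds = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])

# Images are expected channels-last, matching init_weights and the training collate_fn: (batch, H, W, 3).
image_size = model.config.vision_config.image_size
pixel_values = jnp.zeros((2, image_size, image_size, 3), dtype=jnp.float32)
image_embeds = model.get_image_features(pixel_values)

# Cosine similarities, mirroring FlaxHybridCLIPModule.__call__ (without the learned logit scale).
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
print(jnp.matmul(text_embeds, image_embeds.T))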
src/run_medclip.py ADDED
@@ -0,0 +1,559 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Training CLIP-like dual encoder models using text and vision encoders in the library.
+
+ The script can be used to train CLIP-like models for languages other than English by using
+ a text encoder pre-trained in the desired language. Currently this script supports the following vision
+ and text models:
+ Vision models: ViT (https://huggingface.co/models?filter=vit), CLIP (https://huggingface.co/models?filter=clip)
+ Text models: BERT, RoBERTa (https://huggingface.co/models?filter=masked-lm)
+ """
+
+ import json
+ import logging
+ import os
+ import sys
+ import time
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ import torch
+ from torchvision.datasets import VisionDataset
+ from torchvision.io import ImageReadMode, read_image
+ from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
+ from torchvision.transforms.functional import InterpolationMode
+ from tqdm import tqdm
+
+ import jax
+ import jax.numpy as jnp
+ import optax
+ import transformers
+ from flax import jax_utils
+ from flax.jax_utils import unreplicate
+ from flax.training import train_state
+ from flax.training.common_utils import get_metrics, shard, shard_prng_key
+ from modeling_medclip import FlaxHybridCLIP
+ from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
+
+
+ logger = logging.getLogger(__name__)
+
+ # Cache the result
+ has_tensorboard = is_tensorboard_available()
+ if has_tensorboard:
+     try:
+         from flax.metrics.tensorboard import SummaryWriter
+     except ImportError as ie:
+         has_tensorboard = False
+         print(f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}")
+
+ else:
+     print(
+         "Unable to display metrics through TensorBoard because the package is not installed: "
+         "Please run pip install tensorboard to enable."
+     )
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
+     """
+
+     text_model_name_or_path: str = field(
+         metadata={
+             "help": "The text model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     vision_model_name_or_path: str = field(
+         metadata={
+             "help": "The vision model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     from_pt: bool = field(
+         default=True,
+         metadata={"help": "Whether to load the text and vision model using PyTorch checkpoints."},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     tokenizer_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+     )
+     use_fast_tokenizer: bool = field(
+         default=True,
+         metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
+     )
+     dtype: Optional[str] = field(
+         default="float32",
+         metadata={
+             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+         },
+     )
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     """
+
+     data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
+     train_file: Optional[str] = field(
+         default=None, metadata={"help": "The input training data file (a jsonlines file)."}
+     )
+     validation_file: Optional[str] = field(
+         default=None,
+         metadata={"help": "An optional input evaluation data file (a jsonlines file)."},
+     )
+     max_seq_length: Optional[int] = field(
+         default=72,
+         metadata={
+             "help": "The maximum total input sequence length after tokenization. Sequences longer "
+             "than this will be truncated, sequences shorter will be padded."
+         },
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+             "value if set."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+
+     def __post_init__(self):
+         if self.train_file is None and self.validation_file is None:
+             raise ValueError("Need either a dataset name or a training/validation file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension == "json", "`train_file` should be a json file."
+             if self.validation_file is not None:
+                 extension = self.validation_file.split(".")[-1]
+                 assert extension == "json", "`validation_file` should be a json file."
+
+
+ # We use torchvision for faster image pre-processing.
+ # We need to ensure faster processing speed as it can become a bottleneck on TPU
+ class Transform(torch.nn.Module):
+     def __init__(self, image_size):
+         super().__init__()
+         self.transforms = torch.nn.Sequential(
+             Resize([image_size], interpolation=InterpolationMode.BICUBIC),
+             CenterCrop(image_size),
+             ConvertImageDtype(torch.float),
+             Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         with torch.no_grad():
+             x = self.transforms(x)
+         return x
+
+
+ class ImageTextDataset(VisionDataset):
+     """
+     Dataset for loading image-text data for tasks like CLIP training and image captioning.
+
+     Args:
+         root (string): The root path where the dataset is stored.
+         file_path (string): Path to the file containing the image paths and associated captions.
+             The expected format is jsonlines, where each line is a json object containing two keys:
+             `image_path`: The path to the image.
+             `captions`: An `array` of captions.
+         transform (callable, optional): A function/transform that takes in a PIL image
+             and returns a transformed version. E.g., ``transforms.ToTensor``.
+         target_transform (callable, optional): A function/transform that takes in the
+             target and transforms it.
+         transforms (callable, optional): A function/transform that takes input sample and its target as entry
+             and returns a transformed version.
+     """
+
+     def __init__(
+         self,
+         root: str,
+         file_path: str,
+         captions_per_image=2,
+         transform: Optional[Callable] = None,
+         target_transform: Optional[Callable] = None,
+         transforms: Optional[Callable] = None,
+     ):
+         super().__init__(root, transforms, transform, target_transform)
+
+         with open(file_path, "r") as f:
+             examples = [json.loads(line) for line in f.readlines()]
+
+         self.captions = []
+         self.image_paths = []
+
+         for example in examples:
+             self.captions.extend(example["captions"][:captions_per_image])
+             self.image_paths.extend([example["image_path"]] * captions_per_image)
+
+     def _load_image(self, idx: int):
+         path = self.image_paths[idx]
+         return read_image(path, mode=ImageReadMode.RGB)
+
+     def _load_target(self, idx):
+         return self.captions[idx]
+
+     def __getitem__(self, index: int):
+         image = self._load_image(index)
+         target = self._load_target(index)
+
+         if self.transforms is not None:
+             image, target = self.transforms(image, target)
+
+         return image, target
+
+     def __len__(self) -> int:
+         return len(self.captions)
+
+
+ class TrainState(train_state.TrainState):
+     dropout_rng: jnp.ndarray
+
+     def replicate(self):
+         return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+
+
+ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+     summary_writer.scalar("train_time", train_time, step)
+
+     train_metrics = get_metrics(train_metrics)
+     for key, vals in train_metrics.items():
+         tag = f"train_{key}"
+         for i, val in enumerate(vals):
+             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
+
+     for metric_name, value in eval_metrics.items():
+         summary_writer.scalar(f"eval_{metric_name}", value, step)
+
+
+ def create_learning_rate_fn(
+     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+ ) -> Callable[[int], jnp.array]:
+     """Returns a linear warmup, linear decay learning rate function."""
+     steps_per_epoch = train_ds_size // train_batch_size
+     num_train_steps = steps_per_epoch * num_train_epochs
+     warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+     decay_fn = optax.linear_schedule(
+         init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+     )
+     schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
+     return schedule_fn
+
+
+ def main():
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     if (
+         os.path.exists(training_args.output_dir)
+         and os.listdir(training_args.output_dir)
+         and training_args.do_train
+         and not training_args.overwrite_output_dir
+     ):
+         raise ValueError(
+             f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+             "Use --overwrite_output_dir to overcome."
+         )
+
+     # Make one log on every process with the configuration for debugging.
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+     )
+     # Setup logging, we only want one process per machine to log things on the screen.
+     logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
+     if jax.process_index() == 0:
+         transformers.utils.logging.set_verbosity_info()
+     else:
+         transformers.utils.logging.set_verbosity_error()
+
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     if model_args.tokenizer_name:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     elif model_args.text_model_name_or_path:
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+         )
+     else:
+         raise ValueError(
+             "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+             "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+         )
+
+     model = FlaxHybridCLIP.from_text_vision_pretrained(
+         model_args.text_model_name_or_path,
+         model_args.vision_model_name_or_path,
+         seed=training_args.seed,
+         dtype=getattr(jnp, model_args.dtype),
+         text_from_pt=model_args.from_pt,
+         vision_from_pt=model_args.from_pt,
+     )
+     config = model.config
+     # set seed for torch dataloaders
+     set_seed(training_args.seed)
+
+     # Initialize torchvision transforms and jit them for faster processing
+     preprocess = Transform(config.vision_config.image_size)
+     preprocess = torch.jit.script(preprocess)
+
+     # Initialize the image-text dataset
+     train_dataset = ImageTextDataset(
+         data_args.data_dir,
+         data_args.train_file,
+         captions_per_image=2,
+         transform=preprocess,
+     )
+
+     eval_dataset = ImageTextDataset(
+         data_args.data_dir,
+         data_args.validation_file,
+         captions_per_image=1,
+         transform=preprocess,
+     )
+
+     # Store some constants
+     num_epochs = int(training_args.num_train_epochs)
+     train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
+     steps_per_epoch = len(train_dataset) // train_batch_size
+     total_train_steps = steps_per_epoch * num_epochs
+
+     # Use a collate function to tokenize the text and convert the processed images to numpy
+     def collate_fn(examples):
+         pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+         captions = [example[1] for example in examples]
+         inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", return_tensors="np")
+
+         batch = {
+             "pixel_values": pixel_values,
+             "input_ids": inputs["input_ids"],
+             "attention_mask": inputs["attention_mask"],
+         }
+
+         return batch
+
+     # Create data loaders
+     train_loader = torch.utils.data.DataLoader(
+         train_dataset,
+         batch_size=train_batch_size,
+         shuffle=True,
+         num_workers=data_args.preprocessing_num_workers,
+         persistent_workers=True,
+         drop_last=True,
+         collate_fn=collate_fn,
+     )
+
+     eval_loader = torch.utils.data.DataLoader(
+         eval_dataset,
+         batch_size=eval_batch_size,
+         shuffle=False,
+         num_workers=data_args.preprocessing_num_workers,
+         persistent_workers=True,
+         drop_last=True,
+         collate_fn=collate_fn,
+     )
+
+     # Enable tensorboard only on the master node
+     if has_tensorboard and jax.process_index() == 0:
+         summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
+
+     # Initialize our training
+     rng = jax.random.PRNGKey(training_args.seed)
+     rng, dropout_rng = jax.random.split(rng)
+
+     # Create learning rate schedule
+     linear_decay_lr_schedule_fn = create_learning_rate_fn(
+         len(train_dataset),
+         train_batch_size,
+         training_args.num_train_epochs,
+         training_args.warmup_steps,
+         training_args.learning_rate,
+     )
+
+     # Create the AdamW optimizer
+     adamw = optax.adamw(
+         learning_rate=linear_decay_lr_schedule_fn,
+         b1=training_args.adam_beta1,
+         b2=training_args.adam_beta2,
+         eps=training_args.adam_epsilon,
+         weight_decay=training_args.weight_decay,
+     )
+
+     # Setup train state
+     state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+
+     def cross_entropy(logits, axis):
+         logprobs = jax.nn.log_softmax(logits, axis=axis)
+         nll = jnp.diag(logprobs)
+         ce = -jnp.mean(nll)
+         return ce
+
+     def clip_loss(similarity):
+         loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
+         return loss
+
+     # Define gradient update step fn
+     def train_step(state, batch):
+         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
+
+         def compute_loss(params):
+             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+             loss = clip_loss(logits)
+             return loss
+
+         grad_fn = jax.value_and_grad(compute_loss)
+         loss, grad = grad_fn(state.params)
+         grad = jax.lax.pmean(grad, "batch")
+
+         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
+
+         metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+
+         return new_state, metrics
+
+     # Define eval fn
+     def eval_step(params, batch):
+         logits = model(**batch, params=params, train=False)[0]
+         loss = clip_loss(logits)
+
+         # summarize metrics
+         metrics = {"loss": loss}
+         metrics = jax.lax.pmean(metrics, axis_name="batch")
+         return metrics
+
+     # Create parallel version of the train and eval step
+     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
+     p_eval_step = jax.pmap(eval_step, "batch")
+
+     # Replicate the train state on each device
+     state = state.replicate()
+
+     logger.info("***** Running training *****")
+     logger.info(f"  Num examples = {len(train_dataset)}")
+     logger.info(f"  Num Epochs = {num_epochs}")
+     logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
+     logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
+     logger.info(f"  Total optimization steps = {total_train_steps}")
+
+     train_time = 0
+     # Create sampling rng
+     rng, input_rng = jax.random.split(rng)
+
+     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
+     for epoch in epochs:
+         # ======================== Training ================================
+         train_start = time.time()
+
+         # Create sampling rng
+         rng, input_rng = jax.random.split(rng)
+         train_metrics = []
+
+         steps_per_epoch = len(train_dataset) // train_batch_size
+         train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+         # train
+         for batch in train_loader:
+             batch = shard(batch)
+             state, train_metric = p_train_step(state, batch)
+             train_metrics.append(train_metric)
+
+             train_step_progress_bar.update(1)
+
+         train_time += time.time() - train_start
+
+         train_metric = unreplicate(train_metric)
+
+         train_step_progress_bar.close()
+         epochs.write(
+             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+         )
+
+         # ======================== Evaluating ==============================
+         eval_metrics = []
+         eval_steps = len(eval_dataset) // eval_batch_size
+         eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
+         for batch in eval_loader:
+             # Model forward
+             batch = shard(batch)
+             metrics = p_eval_step(state.params, batch)
+             eval_metrics.append(metrics)
+
+             eval_step_progress_bar.update(1)
+
+         # normalize eval metrics
+         eval_metrics = get_metrics(eval_metrics)
+
+         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+
+         # Print metrics and update progress bar
+         eval_step_progress_bar.close()
+         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+         epochs.write(desc)
+         epochs.desc = desc
+
+         # Save metrics
+         if has_tensorboard and jax.process_index() == 0:
+             cur_step = epoch * (len(train_dataset) // train_batch_size)
+             write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+
+         # save checkpoint after each epoch and push checkpoint to the hub
+         if jax.process_index() == 0:
+             params = jax.device_get(unreplicate(state.params))
+             model.save_pretrained(
+                 training_args.output_dir,
+                 params=params,
+                 push_to_hub=training_args.push_to_hub,
+                 commit_message=f"Saving weights and logs of epoch {epoch+1}",
+             )
+
+
+ if __name__ == "__main__":
+     main()
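The script's --train_file/--validation_file point at jsonlines metadata consumed by ImageTextDataset above. A hypothetical snippet (the file name, image paths, and captions below are placeholders) that writes such a file, one JSON object per line with `image_path` and `captions` keys, using a `.json` extension as `__post_init__` requires:

import json

examples = [
    {"image_path": "images/img_0001.png", "captions": ["first caption for image 1", "second caption for image 1"]},
    {"image_path": "images/img_0002.png", "captions": ["only caption for image 2"]},
]

# One JSON object per line; note that _load_image passes image_path to read_image as-is,
# so paths should be absolute or relative to the working directory the script runs from.
with open("train.json", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")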