jykoh committed on
Commit 8b300d9
1 Parent(s): b99e4a0
Files changed (6)
  1. .gitignore +1 -0
  2. README.md +0 -13
  3. app.py +107 -0
  4. fromage/__init__.py +0 -0
  5. fromage/models.py +658 -0
  6. fromage/utils.py +250 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
README.md DELETED
@@ -1,13 +0,0 @@
- ---
- title: Fromage
- emoji: 📚
- colorFrom: purple
- colorTo: yellow
- sdk: gradio
- sdk_version: 3.18.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,107 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+ import matplotlib.pyplot as plt
+
+ from fromage import models
+ from fromage import utils
+ import gradio as gr
+ import huggingface_hub
+ import tempfile
+
+
+ class FromageChatBot:
+   def __init__(self):
+     # Download model from HF Hub.
+     huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
+     huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
+     huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='cc3m_embeddings.pkl')
+     self.model = models.load_fromage('./')
+     self.chat_history = ''
+     self.input_image = None
+
+
+   def reset(self):
+     self.chat_history = ""
+     self.input_image = None
+     return [], []
+
+
+   def upload_image(self, state, image_input):
+     state += [(f"![](/file={image_input.name})", ":)")]
+     self.input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
+     return state, state
+
+
+   def save_image_to_local(self, image: Image.Image):
+     # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
+     filename = next(tempfile._get_candidate_names()) + '.png'
+     image.save(filename)
+     return filename
+
+
+   def generate_for_prompt(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
+     input_prompt = 'Q: ' + input_text + '\nA:'
+     self.chat_history += input_prompt
+
+     # If an image was uploaded, prepend it to the model.
+     model_inputs = None
+     if self.input_image is not None:
+       model_inputs = [self.input_image, self.chat_history]
+     else:
+       model_inputs = [self.chat_history]
+
+     model_outputs = self.model.generate_for_images_and_texts(model_inputs, max_num_rets=num_ims, num_words=num_words, ret_scale_factor=ret_scale_factor, temperature=temp)
+
+     im_names = []
+     response = ''
+     text_outputs = []
+     for output in model_outputs:
+       if type(output) == str:
+         text_outputs.append(output)
+         response += output
+       elif type(output) == list:
+         for image in output:
+           filename = self.save_image_to_local(image)
+           response += f'<img src="/file={filename}">'
+       elif type(output) == Image.Image:
+         filename = self.save_image_to_local(output)
+         response += f'<img src="/file={filename}">'
+
+     self.chat_history += ' '.join(text_outputs)
+     if self.chat_history[-1] != '\n':
+       self.chat_history += '\n'
+     self.input_image = None
+
+     state.append((input_text, response))
+     return state, state
+
+
+   def launch(self):
+     with gr.Blocks(css="#fromage-space {height:600px; overflow-y:auto;}") as demo:
+       chatbot = gr.Chatbot(elem_id="fromage-space")
+       gr_state = gr.State([])
+
+       with gr.Row():
+         with gr.Column(scale=0.85):
+           text_input = gr.Textbox(show_label=False, placeholder="Upload an image [optional]. Then enter a text prompt, and press enter!").style(container=False)
+         with gr.Column(scale=0.15, min_width=0):
+           image_btn = gr.UploadButton("Image", file_types=["image"])
+
+       with gr.Row():
+         with gr.Column(scale=0.20, min_width=0):
+           clear_btn = gr.Button("Clear")
+         ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
+         max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
+         gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
+         gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)
+
+       text_input.submit(self.generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
+       image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
+       clear_btn.click(self.reset, [], [gr_state, chatbot])
+
+     demo.launch(share=False, server_name="0.0.0.0")
+
+
+ chatbot = FromageChatBot()
+ chatbot.launch()
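
For reference, the chat flow above can also be exercised without the Gradio UI. The sketch below is not part of the commit; it assumes the three checkpoint files have already been downloaded into the working directory (as in __init__ above), that a CUDA device is available (load_fromage calls .cuda()), and that 'example.jpg' is a placeholder image path. It mirrors the interleaved prompt that generate_for_prompt builds.

  from PIL import Image
  from fromage import models

  model = models.load_fromage('./')
  image = Image.open('example.jpg').resize((224, 224)).convert('RGB')  # placeholder path, swap in a real image
  prompt = [image, 'Q: What is in this image?\nA:']
  # Returns a list interleaving str captions and List[PIL.Image.Image] retrievals,
  # which is exactly the structure generate_for_prompt() walks to build `response`.
  outputs = model.generate_for_images_and_texts(prompt, max_num_rets=1, num_words=32, ret_scale_factor=1.0, temperature=0.0)
  print(outputs)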
fromage/__init__.py ADDED
File without changes
fromage/models.py ADDED
@@ -0,0 +1,658 @@
+ from typing import Callable, List, Optional, Tuple, Union
+ from collections import namedtuple
+ import json
+ import glob
+ import math
+ import numpy as np
+ import os
+ import torch
+ from torch import Tensor
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from functools import partial
+ import pickle as pkl
+ from PIL import Image, UnidentifiedImageError
+
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+ from transformers import OPTForCausalLM, GPT2Tokenizer
+ from transformers import CLIPVisionModel, CLIPVisionConfig
+
+ from fromage import utils
+
+
+ class FrozenArgs:
+   freeze_lm: bool = True
+   freeze_vm: bool = True
+   opt_version: str = 'facebook/opt-6.7b'
+   visual_encoder: str = 'openai/clip-vit-large-patch14'
+   n_visual_tokens: int = 1
+   image_embed_dropout_prob: float = 0.0
+   task: str = 'captioning'
+   shared_emb_dim: Optional[int] = 256
+   text_emb_layers: List[int] = [-1]
+   retrieval_token_idx: int = 0
+
+
+ class FromageModel(nn.Module):
+   def __init__(self, tokenizer, args: FrozenArgs = FrozenArgs()):
+     super().__init__()
+     self.tokenizer = tokenizer
+     self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
+     self.image_token = self.tokenizer.cls_token_id
+     assert len(args.text_emb_layers) == len(set(args.text_emb_layers)), 'text_emb_layers not unique'
+     self.args = args
+
+     opt_version = args.opt_version
+     visual_encoder = args.visual_encoder
+     n_visual_tokens = args.n_visual_tokens
+     print(f"Using {opt_version} for the language model.")
+     print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")
+
+     self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+     if 'facebook/opt' in opt_version:
+       self.lm = OPTForCausalLM.from_pretrained(opt_version)
+     else:
+       raise NotImplementedError
+
+     self.opt_version = opt_version
+
+     if self.args.freeze_lm:
+       self.lm.eval()
+       print("Freezing the LM.")
+       for param in self.lm.parameters():
+         param.requires_grad = False
+     else:
+       self.lm.train()
+
+     self.retrieval_token_idx = args.retrieval_token_idx
+     print(f'Initializing embedding for the retrieval token [RET] (id = {self.retrieval_token_idx}).')
+     self.lm.resize_token_embeddings(len(tokenizer))
+
+     self.input_embeddings = self.lm.get_input_embeddings()
+
+     print("Restoring pretrained weights for the visual model.")
+     if 'clip' in visual_encoder:
+       self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
+     else:
+       self.visual_model = AutoModel.from_pretrained(visual_encoder)
+
+     if 'clip' in visual_encoder:
+       hidden_size = self.visual_model.config.hidden_size
+     else:
+       raise NotImplementedError
+
+     if self.args.freeze_vm:
+       print("Freezing the VM.")
+       self.visual_model.eval()
+       for param in self.visual_model.parameters():
+         param.requires_grad = False
+     else:
+       self.visual_model.train()
+
+     self.visual_model_name = visual_encoder
+
+     embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
+     self.text_hidden_fcs = nn.ModuleList([])
+     if self.args.shared_emb_dim is None:
+       if len(self.args.text_emb_layers) == 1:
+         if (self.args.text_emb_layers[0] in [-1, self.lm.config.num_hidden_layers]) and ('bert' not in opt_version):
+           out_dim = self.lm.config.word_embed_proj_dim
+         else:
+           out_dim = self.lm.config.hidden_size
+       else:
+         if (-1 in self.args.text_emb_layers) or (self.lm.config.num_hidden_layers in self.args.text_emb_layers) \
+             and (self.lm.config.word_embed_proj_dim != self.lm.config.hidden_size):
+           raise ValueError('No projection dim specified but model uses last output layer and an intermediate one (which have different dims).')
+         else:
+           out_dim = self.lm.config.hidden_size
+     else:
+       out_dim = self.args.shared_emb_dim
+
+     for layer_idx in self.args.text_emb_layers:
+       if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
+         in_dim = self.lm.config.word_embed_proj_dim
+
+         text_fc = [nn.Linear(in_dim, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)]
+         self.text_hidden_fcs.append(nn.Sequential(*text_fc))
+
+       elif layer_idx < self.lm.config.num_hidden_layers:
+         text_fc = [nn.Linear(self.lm.config.hidden_size, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)]
+         self.text_hidden_fcs.append(nn.Sequential(*text_fc))
+       else:
+         raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')
+
+     self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
+     self.visual_fc = nn.Linear(hidden_size, out_dim)
+
+     self.image_dropout = nn.Dropout(self.args.image_embed_dropout_prob)
+
+
+   def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
+     if mode not in ['captioning', 'retrieval']:
+       raise ValueError(f'mode should be one of ["captioning", "retrieval"], got {mode} instead.')
+
+     # Extract visual embeddings from the vision encoder.
+     if 'clip' in self.visual_model_name:
+       outputs = self.visual_model(pixel_values)
+       encoder_outputs = outputs.pooler_output
+     else:
+       raise NotImplementedError
+
+     # Use the correct fc based on function argument.
+     if mode == 'captioning':
+       visual_embs = self.visual_embeddings(encoder_outputs)  # (2, D * n_visual_tokens)
+       visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
+     elif mode == 'retrieval':
+       visual_embs = self.visual_fc(encoder_outputs)  # (2, D * n_visual_tokens)
+       visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
+     else:
+       raise NotImplementedError
+
+     visual_embs = self.image_dropout(visual_embs)
+     return visual_embs
+
+
+   def train(self, mode=True):
+     super(FromageModel, self).train(mode=mode)
+     # Overwrite train() to ensure Frozen models remain frozen.
+     if self.args.freeze_lm:
+       self.lm.eval()
+     if self.args.freeze_vm:
+       self.visual_model.eval()
+
+
+   def forward(
+       self,
+       pixel_values: torch.FloatTensor,
+       labels: torch.LongTensor,
+       caption_len: torch.LongTensor,
+       mode: str = 'captioning',
+       concat_captions: bool = False,
+       input_prefix: Optional[str] = None,
+       inference: bool = False,
+   ):
+     visual_embs = self.get_visual_embs(pixel_values, mode)
+
+     batch_size, vis_seq_len, _ = visual_embs.shape  # vis_seq_len = n_visual_tokens
+     if labels is not None:
+       assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)
+
+     input_embs = self.input_embeddings(labels)  # (N, T, D)
+
+     last_embedding_idx = caption_len - 1  # -1 to retrieve the token before the eos token
+
+     if input_prefix is not None:
+       prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
+       prompt_ids = prompt_ids.to(visual_embs.device)
+       prompt_embs = self.input_embeddings(prompt_ids)
+       prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
+       assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
+       assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
+       assert len(prompt_embs.shape) == 3, prompt_embs.shape
+
+     if mode == 'captioning':
+       # Concat to text embeddings.
+       condition_seq_len = 0
+       if input_prefix is None:
+         # Just add visual embeddings.
+         input_embs = torch.cat([visual_embs, input_embs], axis=1)
+         last_embedding_idx += vis_seq_len
+         condition_seq_len += vis_seq_len
+         full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
+       else:
+         # Add visual and prompt embeddings.
+         prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
+         input_embs = torch.cat([prefix_embs, input_embs], axis=1)
+
+         last_embedding_idx += prefix_embs.shape[1]
+         condition_seq_len += prefix_embs.shape[1]
+         full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
+
+       # Mask out embedding tokens in the labels.
+       full_labels = torch.cat([full_labels, labels], axis=1)
+
+       pad_idx = []
+
+       for label in full_labels:
+         for k, token in enumerate(label):
+           # Mask out retrieval token if it exists.
+           if token in [self.tokenizer.pad_token_id, self.retrieval_token_idx]:
+             label[k:] = -100
+             pad_idx.append(k)
+             break
+           if k == len(label) - 1:  # No padding found.
+             pad_idx.append(k + 1)
+       assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
+
+       bs, seq_len, embs_dim = input_embs.shape
+       if concat_captions:
+         assert len(input_embs.shape) == 3, input_embs
+         assert len(full_labels.shape) == 2, full_labels
+         assert batch_size % 2 == 0
+         all_concat_input_embs = []
+         all_concat_labels = []
+
+         # Rearrange embeddings and labels (and their padding) to concatenate captions.
+         for i in range(batch_size // 2):
+           first_idx = i * 2
+           second_idx = first_idx + 1
+           first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
+           first_labels = full_labels[first_idx, :pad_idx[first_idx]]
+           first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
+           first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
+
+           second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
+           second_labels = full_labels[second_idx, :pad_idx[second_idx]]
+           second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
+           second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
+
+           assert torch.all(first_labels_padding == -100), first_labels_padding
+           assert torch.all(second_labels_padding == -100), second_labels_padding
+           concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0)  # (T*2, 768)
+           concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0)  # (T*2, 768)
+           all_concat_input_embs.append(concat_input_embs)
+           all_concat_labels.append(concat_labels)
+
+         # Pad to max length.
+         input_embs = torch.stack(all_concat_input_embs, axis=0)  # (N/2, T*2, 768)
+         full_labels = torch.stack(all_concat_labels, axis=0)  # (N/2, T*2, 768)
+         assert input_embs.shape == (bs // 2, seq_len * 2, embs_dim), input_embs.shape
+         assert full_labels.shape == (bs // 2, seq_len * 2), full_labels.shape
+
+       output = self.lm(inputs_embeds=input_embs,
+                        labels=full_labels,
+                        output_hidden_states=True)
+     elif mode == 'retrieval':
+       full_labels = torch.clone(labels)
+       if input_prefix is not None:
+         print(f'Adding prefix "{input_prefix}" to retrieval.')
+         # Add prompt embeddings.
+         prefix_embs = prompt_embs
+         input_embs = torch.cat([prefix_embs, input_embs], axis=1)
+         last_embedding_idx += prefix_embs.shape[1]
+         full_labels = torch.cat([
+           torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
+           full_labels
+         ], axis=1)
+
+       pad_idx = []
+       for label in full_labels:
+         for k, token in enumerate(label):
+           if token == self.tokenizer.pad_token_id:
+             label[k:] = -100
+             pad_idx.append(k)
+             break
+           if k == len(label) - 1:  # No padding found.
+             pad_idx.append(k + 1)
+       assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
+
+       output = self.lm(inputs_embeds=input_embs,
+                        labels=full_labels,
+                        output_hidden_states=True)
+     else:
+       raise NotImplementedError
+
+     last_embedding = None
+     last_output_logit = None
+     hidden_states = []
+
+     if mode == 'retrieval':
+       if self.args.shared_emb_dim is not None:
+         for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs):
+           hidden_states.append(fc_layer(output.hidden_states[idx]))  # (N, seq_len, 2048)
+       else:
+         for idx in self.args.text_emb_layers:
+           hidden_states.append(output.hidden_states[idx])
+
+       # Add hidden states together.
+       last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
+
+       if not concat_captions:
+         last_embedding = torch.stack([last_hidden_state[i, last_embedding_idx[i], :] for i in range(batch_size)], axis=0)  # (N, D)
+         last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0)  # (N, D)
+       else:
+         # Concatenate two captioning examples together.
+         all_last_embedding = []
+         all_last_output_logit = []
+         for i in range(batch_size // 2):
+           first_last_embedding_idx, second_last_embedding_idx = all_last_embedding_idx[i]
+           first_last_embedding = last_hidden_state[i, first_last_embedding_idx, :]  # (N, D)
+           first_last_output_logit = output.logits[i, first_last_embedding_idx - 1, :]  # (N, D)
+           second_last_embedding = last_hidden_state[i, second_last_embedding_idx, :]  # (N, D)
+           second_last_output_logit = output.logits[i, second_last_embedding_idx - 1, :]  # (N, D)
+           all_last_embedding.append(first_last_embedding)
+           all_last_embedding.append(second_last_embedding)
+           all_last_output_logit.append(first_last_output_logit)
+           all_last_output_logit.append(second_last_output_logit)
+
+         last_embedding = torch.stack(all_last_embedding)
+         last_output_logit = torch.stack(all_last_output_logit)
+
+       # Compute retrieval loss.
+       assert visual_embs.shape[1] == 1, visual_embs.shape
+       visual_embs = visual_embs[:, 0, :]
+       visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
+       last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)
+
+       # cosine similarity as logits
+       logit_scale = self.logit_scale.exp()
+       visual_embs = logit_scale * visual_embs
+     elif mode == 'captioning':
+       pass
+     else:
+       raise NotImplementedError
+
+     return output, full_labels, last_embedding, last_output_logit, visual_embs
+
+   def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
+                temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
+                ret_scale_factor: float = 1.0, filter_value: float = -float('Inf')):
+     """Runs greedy decoding and returns generated captions.
+
+     Args:
+       embeddings: Input condition that the model uses for autoregressive generation.
+       max_len: Maximum number of tokens to generate.
+       temperature: Used to modulate logit distribution.
+       top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
+       min_word_tokens: Minimum number of words to generate before allowing a [RET] output.
+       ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs.
+       filter_value: Value to assign to tokens that should never be generated.
+     Outputs:
+       out: (N, T) int32 sequence of output tokens.
+       output_embeddings: (N, T, 256) sequence of text output embeddings.
+     """
+     self.lm.eval()
+
+     with torch.no_grad():  # no tracking history
+       batch_size, s, _ = embeddings.shape
+       # init output with image tokens
+       out = None
+       past_key_values = None
+       output_embeddings = []
+       output_logits = []
+
+       for i in range(max_len):
+         if 'opt' in self.opt_version:
+           output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
+         else:
+           if i == 0:
+             output = self.lm(inputs_embeds=embeddings, use_cache=True, past_key_values=None, output_hidden_states=True)
+           else:
+             output = self.lm(input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values, output_hidden_states=True)
+
+         # Collect and sum the hidden states.
+         hidden_states = []
+         if self.args.shared_emb_dim is not None:
+           for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs):
+             hidden_states.append(fc_layer(output.hidden_states[idx]))  # (N, seq_len, 2048)
+         else:
+           for idx in self.args.text_emb_layers:
+             hidden_states.append(output.hidden_states[idx])
+         # Add hidden states together.
+         last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)  # (N, T, 256)
+         last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True)
+         output_embeddings.append(last_embedding)
+
+         logits = output.logits[:, -1, :]  # (N, vocab_size)
+         if top_p == 1.0:
+           logits = logits.cpu()
+         output_logits.append(logits)
+
+         if self.retrieval_token_idx != -1 and self.retrieval_token_idx is not None:
+           if i < min_word_tokens:
+             # Eliminate probability of generating [RET] if this is earlier than min_word_tokens.
+             logits[:, self.retrieval_token_idx] = filter_value
+           else:
+             # Multiply by scaling factor.
+             logits[:, self.retrieval_token_idx] = logits[:, self.retrieval_token_idx] * ret_scale_factor
+
+         past_key_values = output.past_key_values
+
+         if temperature == 0.0:
+           if top_p != 1.0:
+             raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
+           next_token = torch.argmax(logits, keepdim=True, dim=-1)  # (N, 1)
+         else:
+           logits = logits / temperature
+
+           # Apply top-p filtering.
+           if top_p < 1.0:
+             assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True)  # (N, D) and (N, D)
+             cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # (N, D)
+
+             # Remove tokens with cumulative probability above the threshold
+             sorted_indices_to_remove = cumulative_probs > top_p
+             # Shift the indices to the right to keep also the first token above the threshold
+             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+             sorted_indices_to_remove[..., 0] = 0
+
+             for j in range(sorted_indices.shape[0]):
+               indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
+               logits[j, indices_to_remove] = filter_value
+
+           token_weights = logits.exp()  # (N, vocab_size)
+           next_token = torch.multinomial(token_weights, 1)  # (N, 1)
+
+         next_token = next_token.long().to(embeddings.device)
+         if out is not None:
+           out = torch.cat([out, next_token], dim=-1)
+         else:
+           out = next_token
+
+         if 'opt' in self.opt_version:
+           next_embedding = self.input_embeddings(next_token)
+           embeddings = torch.cat([embeddings, next_embedding], dim=1)
+         elif (self.tokenizer.eos_token_id and (next_token == self.tokenizer.eos_token_id).all()):
+           # End of generation.
+           break
+
+     return out, output_embeddings, output_logits
+
+
+ class Fromage(nn.Module):
+   def __init__(self, tokenizer, model_args: Optional[FrozenArgs] = None,
+                path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.tensor] = None):
+     super().__init__()
+     self.model = FromageModel(tokenizer, model_args)
+     self.path_array = path_array
+     self.emb_matrix = emb_matrix
+
+   def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
+                generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
+                ret_scale_factor: float = 1.0, min_word_tokens: int = 0,
+                mode: str = 'captioning', concat_captions: bool = False,
+                input_prefix: Optional[str] = None, inference: bool = False) -> Tensor:
+     if generate:
+       return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
+                                  min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor)
+     else:
+       output = self.model(
+         pixel_values = images,
+         labels = tgt_tokens,
+         caption_len = caption_len,
+         mode = mode,
+         concat_captions = concat_captions,
+         input_prefix = input_prefix,
+         inference = inference)
+       return output
+
+   def generate_for_images_and_texts(
+       self, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0, temperature: float = 0.0,
+       max_num_rets: int = 1):
+     """
+     Encode prompts into embeddings.
+
+     Args:
+       prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
+       num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs.
+       ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs.
+       top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
+       temperature: Used to modulate logit distribution.
+       max_num_rets: Maximum number of images to return in one generation pass.
+     Returns:
+       return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs.
+     """
+     input_embs = []
+     input_ids = []
+     add_bos = True
+
+     for i, p in enumerate(prompts):
+       if type(p) == Image.Image:
+         # Encode as image.
+         pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
+         pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
+         pixel_values = pixel_values[None, ...]
+
+         visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning')  # (1, n_visual_tokens, D)
+         input_embs.append(visual_embs)
+       elif type(p) == str:
+         text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
+         if not add_bos:
+           # Remove <bos> tag.
+           text_ids = text_ids[:, 1:]
+         else:
+           # Only add <bos> once.
+           add_bos = False
+
+         text_embs = self.model.input_embeddings(text_ids)  # (1, T, D)
+         input_embs.append(text_embs)
+         input_ids.append(text_ids)
+       else:
+         raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
+     input_embs = torch.cat(input_embs, dim=1)
+     input_ids = torch.cat(input_ids, dim=1)
+
+     if num_words == 0:
+       generated_ids = input_ids
+       outputs = self.model.lm(inputs_embeds=input_embs, use_cache=False, output_hidden_states=True)
+       # Map outputs to embeddings, so we can retrieve embeddings from the [RET] tokens.
+       out = []
+       for x, fc in zip(self.model.args.text_emb_layers, self.model.text_hidden_fcs):
+         out.append(fc(outputs.hidden_states[x]))
+       embeddings = torch.stack(out, dim=-1).sum(dim=-1)
+       embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (N, T, 256)
+     elif num_words > 0:
+       generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words,
+                                                                    temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
+       embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
+
+       # Truncate to newline.
+       newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
+       trunc_idx = 0
+       for j in range(generated_ids.shape[1]):
+         if generated_ids[0, j] == newline_token_id:
+           trunc_idx = j
+           break
+       if trunc_idx > 0:
+         generated_ids = generated_ids[:, :trunc_idx]
+         embeddings = embeddings[:, :trunc_idx]
+     else:
+       raise ValueError
+
+     # Save outputs as an interleaved list.
+     return_outputs = []
+     # Find up to max_num_rets [RET] tokens, and their corresponding scores.
+     all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx) if x][:max_num_rets]
+     seen_image_idx = []  # Avoid showing the same image multiple times.
+
+     last_ret_idx = 0
+     if len(all_ret_idx) == 0:
+       # No [RET] tokens.
+       caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+       return_outputs.append(utils.truncate_caption(caption))
+     else:
+       for ret_idx in all_ret_idx:
+         ret_emb = embeddings[:, ret_idx, :]
+         scores = self.emb_matrix @ ret_emb.T
+
+         # Downweight seen images.
+         for seen_idx in seen_image_idx:
+           scores[seen_idx, :] -= 1000
+
+         # Get the top 3 images for each image.
+         _, top_image_idx = scores.squeeze().topk(3)
+         image_outputs = []
+         for img_idx in top_image_idx:
+           # Find the first image that does not error out.
+           try:
+             seen_image_idx.append(img_idx)
+             img = utils.get_image_from_url(self.path_array[img_idx])
+             image_outputs.append(img)
+             if len(image_outputs) == max_num_rets:
+               break
+           except UnidentifiedImageError:
+             pass
+
+         caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
+         last_ret_idx = ret_idx + 1
+         return_outputs.append(utils.truncate_caption(caption) + ' [RET]')
+         return_outputs.append(image_outputs)
+
+     return return_outputs
+
+
+ def load_fromage(model_dir: str) -> Fromage:
+   model_args_path = os.path.join(model_dir, 'model_args.json')
+   model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
+   embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]
+
+   if not os.path.exists(model_args_path):
+     raise ValueError(f'model_args.json does not exist in {model_dir}.')
+   if not os.path.exists(model_ckpt_path):
+     raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')
+   if len(embs_paths) == 0:
+     raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {model_dir}.')
+
+   # Load embeddings.
+   # Construct embedding matrix for nearest neighbor lookup.
+   path_array = []
+   emb_matrix = []
+
+   # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
+   for p in embs_paths:
+     with open(p, 'rb') as wf:
+       train_embs_data = pkl.load(wf)
+       path_array.extend(train_embs_data['paths'])
+       emb_matrix.append(train_embs_data['embeddings'])
+   emb_matrix = np.concatenate(emb_matrix, axis=0)
+
+   # Number of paths should be equal to number of embeddings.
+   assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape[0])
+
+   with open(model_args_path, 'r') as f:
+     model_kwargs = json.load(f)
+
+   # Initialize tokenizer.
+   tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
+   tokenizer.pad_token = tokenizer.eos_token
+   # Add special tokens to the model to enable [RET].
+   tokenizer.add_special_tokens({"cls_token": "<|image|>"})
+   tokenizer.add_tokens('[RET]')
+   ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
+   assert len(ret_token_idx) == 1, ret_token_idx
+   model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
+   args = namedtuple('args', model_kwargs)(**model_kwargs)
+
+   # Initialize model for inference.
+   model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
+   model = model.eval()
+   model = model.bfloat16()
+   model = model.cuda()
+
+   # Load pretrained linear mappings and [RET] embeddings.
+   checkpoint = torch.load(model_ckpt_path)
+   model.load_state_dict(checkpoint['state_dict'], strict=False)
+   with torch.no_grad():
+     model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())
+
+   logit_scale = model.model.logit_scale.exp()
+   emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
+   emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
+   emb_matrix = logit_scale * emb_matrix
+   model.emb_matrix = emb_matrix
+
+   return model
+
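
The retrieval path in generate_for_images_and_texts reduces to a scaled cosine-similarity lookup: load_fromage L2-normalizes the precomputed CC3M embeddings and multiplies them by logit_scale, and at generation time the normalized [RET] hidden state is dotted against that matrix. A self-contained sketch of just that lookup, using random tensors as stand-ins for the real embeddings (not part of the commit):

  import torch

  emb_matrix = torch.randn(1000, 256)  # stand-in for the precomputed CC3M image embeddings, shape (num_images, D)
  ret_emb = torch.randn(1, 256)        # stand-in for the [RET] embedding produced by a text_hidden_fcs layer, shape (1, D)

  emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)  # L2-normalize rows, as in load_fromage
  ret_emb = ret_emb / ret_emb.norm(dim=1, keepdim=True)

  scores = emb_matrix @ ret_emb.T              # (num_images, 1) cosine similarities
  _, top_image_idx = scores.squeeze().topk(3)  # indices of the 3 nearest images, as in generate_for_images_and_texts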
fromage/utils.py ADDED
@@ -0,0 +1,250 @@
+ from enum import Enum
+ import subprocess
+ import sys
+ import shutil
+ import torch
+ import torch.distributed as dist
+ from torchvision.transforms import functional as F
+ from torchvision import transforms as T
+ from transformers import AutoFeatureExtractor
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
+ import requests
+ from io import BytesIO
+
+ import random
+
+
+ def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']):
+   """Logs git status to stdout."""
+   subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file)
+   subprocess.call('echo', shell=True, stdout=out_file)
+   exclude_string = ''
+   subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file)
+
+
+ def get_image_from_url(url: str):
+   response = requests.get(url)
+   img = Image.open(BytesIO(response.content))
+   img = img.resize((224, 224))
+   img = img.convert('RGB')
+   return img
+
+
+ def truncate_caption(caption: str) -> str:
+   """Truncate captions at periods and newlines."""
+   trunc_index = caption.find('\n') + 1
+   if trunc_index <= 0:
+     trunc_index = caption.find('.') + 1
+   caption = caption[:trunc_index]
+   return caption
+
+
+ def pad_to_size(x, size=256):
+   delta_w = size - x.size[0]
+   delta_h = size - x.size[1]
+   padding = (
+     delta_w // 2,
+     delta_h // 2,
+     delta_w - (delta_w // 2),
+     delta_h - (delta_h // 2),
+   )
+   new_im = ImageOps.expand(x, padding)
+   return new_im
+
+
+ class RandCropResize(object):
+
+   """
+   Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092
+   """
+
+   def __init__(self, target_size):
+     self.target_size = target_size
+
+   def __call__(self, img):
+     img = pad_to_size(img, self.target_size)
+     d_min = min(img.size)
+     img = T.RandomCrop(size=d_min)(img)
+     t_min = min(d_min, round(9 / 8 * self.target_size))
+     t_max = min(d_min, round(12 / 8 * self.target_size))
+     t = random.randint(t_min, t_max + 1)
+     img = T.Resize(t)(img)
+     if min(img.size) < 256:
+       img = T.Resize(256)(img)
+     return T.RandomCrop(size=self.target_size)(img)
+
+
+ class SquarePad(object):
+   """Pads image to square.
+   From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
+   """
+   def __call__(self, image):
+     max_wh = max(image.size)
+     p_left, p_top = [(max_wh - s) // 2 for s in image.size]
+     p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
+     padding = (p_left, p_top, p_right, p_bottom)
+     return F.pad(image, padding, 0, 'constant')
+
+
+ def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
+   """Creates a (3, nrows * 14, width) image of text.
+
+   Returns:
+     cap_img: (3, 14 * nrows, width) image of wrapped text.
+   """
+   height = 12
+   padding = 5
+   effective_width = width - 2 * padding
+   # Create a black image to draw text on.
+   cap_img = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0))
+   draw = ImageDraw.Draw(cap_img)
+   draw.text((0, 0), text, color, font=font or ImageFont.load_default())
+   cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32)  # (3, height, W * nrows)
+   cap_img = torch.split(cap_img, effective_width, dim=-1)  # List of nrow elements of shape (3, height, W)
+   cap_img = torch.cat(cap_img, dim=1)  # (3, height * nrows, W)
+   # Add zero padding.
+   cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding])
+   return cap_img
+
+
+ def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
+   print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
+   feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+   return feature_extractor
+
+
+ def get_pixel_values_for_model(feature_extractor, img):
+   pixel_values = feature_extractor(
+     img.convert('RGB'),
+     return_tensors="pt").pixel_values[0, ...]  # (3, H, W)
+   return pixel_values
+
+
+ def save_checkpoint(state, is_best, filename='checkpoint'):
+   torch.save(state, filename + '.pth.tar')
+   if is_best:
+     shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar')
+
+
+ def accuracy(output, target, padding, topk=(1,)):
+   """Computes the accuracy over the k top predictions for the specified values of k"""
+   with torch.no_grad():
+     maxk = max(topk)
+     if output.shape[-1] < maxk:
+       print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")
+
+     maxk = min(maxk, output.shape[-1])
+     batch_size = target.size(0)
+
+     # Take topk along the last dimension.
+     _, pred = output.topk(maxk, -1, True, True)  # (N, T, topk)
+
+     mask = (target != padding).type(target.dtype)
+     target_expand = target[..., None].expand_as(pred)
+     correct = pred.eq(target_expand)
+     correct = correct * mask[..., None].expand_as(correct)
+
+     res = []
+     for k in topk:
+       correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
+       res.append(correct_k.mul_(100.0 / mask.sum()))
+     return res
+
+
+ def get_params_count(model, max_name_len: int = 60):
+   params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()]
+   total_trainable_params = sum([x[1] for x in params if x[-1]])
+   total_nontrainable_params = sum([x[1] for x in params if not x[-1]])
+   return params, total_trainable_params, total_nontrainable_params
+
+
+ def get_params_count_str(model, max_name_len: int = 60):
+   padding = 70  # Hardcoded depending on desired amount of padding and separators.
+   params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len)
+   param_counts_text = ''
+   param_counts_text += '=' * (max_name_len + padding) + '\n'
+   param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n'
+   param_counts_text += '-' * (max_name_len + padding) + '\n'
+   for name, param_count, shape, trainable in params:
+     param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
+   param_counts_text += '-' * (max_name_len + padding) + '\n'
+   param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n'
+   param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n'
+   param_counts_text += '=' * (max_name_len + padding) + '\n'
+   return param_counts_text
+
+
+ class Summary(Enum):
+   NONE = 0
+   AVERAGE = 1
+   SUM = 2
+   COUNT = 3
+
+
+ class ProgressMeter(object):
+   def __init__(self, num_batches, meters, prefix=""):
+     self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+     self.meters = meters
+     self.prefix = prefix
+
+   def display(self, batch):
+     entries = [self.prefix + self.batch_fmtstr.format(batch)]
+     entries += [str(meter) for meter in self.meters]
+     print('\t'.join(entries))
+
+   def display_summary(self):
+     entries = [" *"]
+     entries += [meter.summary() for meter in self.meters]
+     print(' '.join(entries))
+
+   def _get_batch_fmtstr(self, num_batches):
+     num_digits = len(str(num_batches // 1))
+     fmt = '{:' + str(num_digits) + 'd}'
+     return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+
+ class AverageMeter(object):
+   """Computes and stores the average and current value"""
+   def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+     self.name = name
+     self.fmt = fmt
+     self.summary_type = summary_type
+     self.reset()
+
+   def reset(self):
+     self.val = 0
+     self.avg = 0
+     self.sum = 0
+     self.count = 0
+
+   def update(self, val, n=1):
+     self.val = val
+     self.sum += val * n
+     self.count += n
+     self.avg = self.sum / self.count
+
+   def all_reduce(self):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
+     dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
+     self.sum, self.count = total.tolist()
+     self.avg = self.sum / self.count
+
+   def __str__(self):
+     fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+     return fmtstr.format(**self.__dict__)
+
+   def summary(self):
+     fmtstr = ''
+     if self.summary_type is Summary.NONE:
+       fmtstr = ''
+     elif self.summary_type is Summary.AVERAGE:
+       fmtstr = '{name} {avg:.3f}'
+     elif self.summary_type is Summary.SUM:
+       fmtstr = '{name} {sum:.3f}'
+     elif self.summary_type is Summary.COUNT:
+       fmtstr = '{name} {count:.3f}'
+     else:
+       raise ValueError('invalid summary type %r' % self.summary_type)
+
+     return fmtstr.format(**self.__dict__)
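
AverageMeter and ProgressMeter follow the usual PyTorch ImageNet-example pattern. A minimal usage sketch (the loss values below are made up for illustration and are not produced by this repo):

  losses = AverageMeter('Loss', ':.3f')
  progress = ProgressMeter(100, [losses], prefix='Epoch [0]')

  for i in range(100):
    batch_loss = 1.0 / (i + 1)       # placeholder standing in for a real training loss
    losses.update(batch_loss, n=32)  # n = batch size
    if i % 10 == 0:
      progress.display(i)
  progress.display_summary()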