Spaces:
Build error
Add app
Browse files
- .gitignore +1 -0
- README.md +0 -13
- app.py +107 -0
- fromage/__init__.py +0 -0
- fromage/models.py +658 -0
- fromage/utils.py +250 -0
.gitignore
ADDED
@@ -0,0 +1 @@
+.DS_Store
README.md
DELETED
@@ -1,13 +0,0 @@
----
-title: Fromage
-emoji: 📚
-colorFrom: purple
-colorTo: yellow
-sdk: gradio
-sdk_version: 3.18.0
-app_file: app.py
-pinned: false
-license: apache-2.0
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,107 @@
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt

from fromage import models
from fromage import utils
import gradio as gr
import huggingface_hub
import tempfile


class FromageChatBot:
    def __init__(self):
        # Download model from HF Hub.
        huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
        huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
        huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='cc3m_embeddings.pkl')
        self.model = models.load_fromage('./')
        self.chat_history = ''
        self.input_image = None

    def reset(self):
        self.chat_history = ""
        self.input_image = None
        return [], []

    def upload_image(self, state, image_input):
        state += [(f"![](/file={image_input.name})", ":)")]
        self.input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
        return state, state

    def save_image_to_local(self, image: Image.Image):
        # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
        filename = next(tempfile._get_candidate_names()) + '.png'
        image.save(filename)
        return filename

    def generate_for_prompt(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
        input_prompt = 'Q: ' + input_text + '\nA:'
        self.chat_history += input_prompt

        # If an image was uploaded, prepend it to the model inputs.
        model_inputs = None
        if self.input_image is not None:
            model_inputs = [self.input_image, self.chat_history]
        else:
            model_inputs = [self.chat_history]

        model_outputs = self.model.generate_for_images_and_texts(
            model_inputs, max_num_rets=num_ims, num_words=num_words,
            ret_scale_factor=ret_scale_factor, temperature=temp)

        im_names = []
        response = ''
        text_outputs = []
        for output in model_outputs:
            if type(output) == str:
                text_outputs.append(output)
                response += output
            elif type(output) == list:
                for image in output:
                    filename = self.save_image_to_local(image)
                    response += f'<img src="/file={filename}">'
            elif type(output) == Image.Image:
                filename = self.save_image_to_local(output)
                response += f'<img src="/file={filename}">'

        # Append only the generated text (not the retrieved images) to the running history.
        self.chat_history += ' '.join(text_outputs)
        if self.chat_history[-1] != '\n':
            self.chat_history += '\n'
        self.input_image = None

        state.append((input_text, response))
        return state, state

    def launch(self):
        with gr.Blocks(css="#fromage-space {height:600px; overflow-y:auto;}") as demo:
            chatbot = gr.Chatbot(elem_id="fromage-space")
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    text_input = gr.Textbox(show_label=False, placeholder="Upload an image [optional]. Then enter a text prompt, and press enter!").style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    image_btn = gr.UploadButton("Image", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    clear_btn = gr.Button("Clear")
                ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, label="Multiplier for returning images (higher means more frequent)")
                max_ret_images = gr.Number(minimum=0, maximum=3, value=1, precision=1, interactive=True, label="Max images to return")
                gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
                gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

            text_input.submit(self.generate_for_prompt, [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot])
            image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
            clear_btn.click(self.reset, [], [gr_state, chatbot])

        demo.launch(share=False, server_name="0.0.0.0")


chatbot = FromageChatBot()
chatbot.launch()
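Note: `generate_for_images_and_texts` returns an interleaved list of strings and lists of PIL images, which `generate_for_prompt` above flattens into HTML for the Gradio chatbot. A minimal sketch of consuming that return value outside of Gradio (the input image path and output file names here are hypothetical):

# Sketch only, not part of the Space: iterate over the interleaved outputs
# of generate_for_images_and_texts. 'example.png' and the saved file names
# are hypothetical placeholders.
from PIL import Image
from fromage import models

model = models.load_fromage('./')
outputs = model.generate_for_images_and_texts(
    [Image.open('example.png').resize((224, 224)).convert('RGB'), 'Q: What is this?\nA:'],
    num_words=32, max_num_rets=1)

for item in outputs:
    if isinstance(item, str):
        print(item)                      # generated text (may end with ' [RET]')
    else:
        for k, img in enumerate(item):   # list of retrieved PIL.Image.Image objects
            img.save(f'retrieved_{k}.png')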
fromage/__init__.py
ADDED
File without changes
fromage/models.py
ADDED
@@ -0,0 +1,658 @@
from typing import Callable, List, Optional, Tuple, Union
from collections import namedtuple
import json
import glob
import math
import numpy as np
import os
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from functools import partial
import pickle as pkl
from PIL import Image, UnidentifiedImageError

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers import OPTForCausalLM, GPT2Tokenizer
from transformers import CLIPVisionModel, CLIPVisionConfig

from fromage import utils


class FrozenArgs:
    freeze_lm: bool = True
    freeze_vm: bool = True
    opt_version: str = 'facebook/opt-6.7b'
    visual_encoder: str = 'openai/clip-vit-large-patch14'
    n_visual_tokens: int = 1
    image_embed_dropout_prob: float = 0.0
    task: str = 'captioning'
    shared_emb_dim: Optional[int] = 256
    text_emb_layers: List[int] = [-1]
    retrieval_token_idx: int = 0


class FromageModel(nn.Module):
    def __init__(self, tokenizer, args: FrozenArgs = FrozenArgs()):
        super().__init__()
        self.tokenizer = tokenizer
        self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
        self.image_token = self.tokenizer.cls_token_id
        assert len(args.text_emb_layers) == len(set(args.text_emb_layers)), 'text_emb_layers not unique'
        self.args = args

        opt_version = args.opt_version
        visual_encoder = args.visual_encoder
        n_visual_tokens = args.n_visual_tokens
        print(f"Using {opt_version} for the language model.")
        print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if 'facebook/opt' in opt_version:
            self.lm = OPTForCausalLM.from_pretrained(opt_version)
        else:
            raise NotImplementedError

        self.opt_version = opt_version

        if self.args.freeze_lm:
            self.lm.eval()
            print("Freezing the LM.")
            for param in self.lm.parameters():
                param.requires_grad = False
        else:
            self.lm.train()

        self.retrieval_token_idx = args.retrieval_token_idx
        print(f'Initializing embedding for the retrieval token [RET] (id = {self.retrieval_token_idx}).')
        self.lm.resize_token_embeddings(len(tokenizer))

        self.input_embeddings = self.lm.get_input_embeddings()

        print("Restoring pretrained weights for the visual model.")
        if 'clip' in visual_encoder:
            self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
        else:
            self.visual_model = AutoModel.from_pretrained(visual_encoder)

        if 'clip' in visual_encoder:
            hidden_size = self.visual_model.config.hidden_size
        else:
            raise NotImplementedError

        if self.args.freeze_vm:
            print("Freezing the VM.")
            self.visual_model.eval()
            for param in self.visual_model.parameters():
                param.requires_grad = False
        else:
            self.visual_model.train()

        self.visual_model_name = visual_encoder

        embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
        self.text_hidden_fcs = nn.ModuleList([])
        if self.args.shared_emb_dim is None:
            if len(self.args.text_emb_layers) == 1:
                if (self.args.text_emb_layers[0] in [-1, self.lm.config.num_hidden_layers]) and ('bert' not in opt_version):
                    out_dim = self.lm.config.word_embed_proj_dim
                else:
                    out_dim = self.lm.config.hidden_size
            else:
                if (-1 in self.args.text_emb_layers) or (self.lm.config.num_hidden_layers in self.args.text_emb_layers) \
                        and (self.lm.config.word_embed_proj_dim != self.lm.config.hidden_size):
                    raise ValueError('No projection dim specified but model uses last output layer and an intermediate one (which have different dims).')
                else:
                    out_dim = self.lm.config.hidden_size
        else:
            out_dim = self.args.shared_emb_dim

        for layer_idx in self.args.text_emb_layers:
            if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
                in_dim = self.lm.config.word_embed_proj_dim

                text_fc = [nn.Linear(in_dim, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)]
                self.text_hidden_fcs.append(nn.Sequential(*text_fc))

            elif layer_idx < self.lm.config.num_hidden_layers:
                text_fc = [nn.Linear(self.lm.config.hidden_size, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)]
                self.text_hidden_fcs.append(nn.Sequential(*text_fc))
            else:
                raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')

        self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
        self.visual_fc = nn.Linear(hidden_size, out_dim)

        self.image_dropout = nn.Dropout(self.args.image_embed_dropout_prob)

    def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
        if mode not in ['captioning', 'retrieval']:
            raise ValueError(f'mode should be one of ["captioning", "retrieval"], got {mode} instead.')

        # Extract visual embeddings from the vision encoder.
        if 'clip' in self.visual_model_name:
            outputs = self.visual_model(pixel_values)
            encoder_outputs = outputs.pooler_output
        else:
            raise NotImplementedError

        # Use the correct fc based on function argument.
        if mode == 'captioning':
            visual_embs = self.visual_embeddings(encoder_outputs)  # (2, D * n_visual_tokens)
            visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
        elif mode == 'retrieval':
            visual_embs = self.visual_fc(encoder_outputs)  # (2, D * n_visual_tokens)
            visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
        else:
            raise NotImplementedError

        visual_embs = self.image_dropout(visual_embs)
        return visual_embs

    def train(self, mode=True):
        super(FromageModel, self).train(mode=mode)
        # Overwrite train() to ensure frozen models remain frozen.
        if self.args.freeze_lm:
            self.lm.eval()
        if self.args.freeze_vm:
            self.visual_model.eval()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor,
        caption_len: torch.LongTensor,
        mode: str = 'captioning',
        concat_captions: bool = False,
        input_prefix: Optional[str] = None,
        inference: bool = False,
    ):
        visual_embs = self.get_visual_embs(pixel_values, mode)

        batch_size, vis_seq_len, _ = visual_embs.shape  # vis_seq_len = n_visual_tokens
        if labels is not None:
            assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)

        input_embs = self.input_embeddings(labels)  # (N, T, D)

        last_embedding_idx = caption_len - 1  # -1 to retrieve the token before the eos token

        if input_prefix is not None:
            prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
            prompt_ids = prompt_ids.to(visual_embs.device)
            prompt_embs = self.input_embeddings(prompt_ids)
            prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
            assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
            assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
            assert len(prompt_embs.shape) == 3, prompt_embs.shape

        if mode == 'captioning':
            # Concat to text embeddings.
            condition_seq_len = 0
            if input_prefix is None:
                # Just add visual embeddings.
                input_embs = torch.cat([visual_embs, input_embs], axis=1)
                last_embedding_idx += vis_seq_len
                condition_seq_len += vis_seq_len
                full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
            else:
                # Add visual and prompt embeddings.
                prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
                input_embs = torch.cat([prefix_embs, input_embs], axis=1)

                last_embedding_idx += prefix_embs.shape[1]
                condition_seq_len += prefix_embs.shape[1]
                full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100

            # Mask out embedding tokens in the labels.
            full_labels = torch.cat([full_labels, labels], axis=1)

            pad_idx = []

            for label in full_labels:
                for k, token in enumerate(label):
                    # Mask out retrieval token if it exists.
                    if token in [self.tokenizer.pad_token_id, self.retrieval_token_idx]:
                        label[k:] = -100
                        pad_idx.append(k)
                        break
                    if k == len(label) - 1:  # No padding found.
                        pad_idx.append(k + 1)
            assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

            bs, seq_len, embs_dim = input_embs.shape
            if concat_captions:
                assert len(input_embs.shape) == 3, input_embs
                assert len(full_labels.shape) == 2, full_labels
                assert batch_size % 2 == 0
                all_concat_input_embs = []
                all_concat_labels = []

                # Rearrange embeddings and labels (and their padding) to concatenate captions.
                for i in range(batch_size // 2):
                    first_idx = i * 2
                    second_idx = first_idx + 1
                    first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
                    first_labels = full_labels[first_idx, :pad_idx[first_idx]]
                    first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
                    first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]

                    second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
                    second_labels = full_labels[second_idx, :pad_idx[second_idx]]
                    second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
                    second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]

                    assert torch.all(first_labels_padding == -100), first_labels_padding
                    assert torch.all(second_labels_padding == -100), second_labels_padding
                    concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0)  # (T*2, 768)
                    concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0)  # (T*2, 768)
                    all_concat_input_embs.append(concat_input_embs)
                    all_concat_labels.append(concat_labels)

                # Pad to max length.
                input_embs = torch.stack(all_concat_input_embs, axis=0)  # (N/2, T*2, 768)
                full_labels = torch.stack(all_concat_labels, axis=0)  # (N/2, T*2, 768)
                assert input_embs.shape == (bs // 2, seq_len * 2, embs_dim), input_embs.shape
                assert full_labels.shape == (bs // 2, seq_len * 2), full_labels.shape

            output = self.lm(inputs_embeds=input_embs,
                             labels=full_labels,
                             output_hidden_states=True)
        elif mode == 'retrieval':
            full_labels = torch.clone(labels)
            if input_prefix is not None:
                print(f'Adding prefix "{input_prefix}" to retrieval.')
                # Add prompt embeddings.
                prefix_embs = prompt_embs
                input_embs = torch.cat([prefix_embs, input_embs], axis=1)
                last_embedding_idx += prefix_embs.shape[1]
                full_labels = torch.cat([
                    torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
                    full_labels
                ], axis=1)

            pad_idx = []
            for label in full_labels:
                for k, token in enumerate(label):
                    if token == self.tokenizer.pad_token_id:
                        label[k:] = -100
                        pad_idx.append(k)
                        break
                    if k == len(label) - 1:  # No padding found.
                        pad_idx.append(k + 1)
            assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

            output = self.lm(inputs_embeds=input_embs,
                             labels=full_labels,
                             output_hidden_states=True)
        else:
            raise NotImplementedError

        last_embedding = None
        last_output_logit = None
        hidden_states = []

        if mode == 'retrieval':
            if self.args.shared_emb_dim is not None:
                for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs):
                    hidden_states.append(fc_layer(output.hidden_states[idx]))  # (N, seq_len, 2048)
            else:
                for idx in self.args.text_emb_layers:
                    hidden_states.append(output.hidden_states[idx])

            # Add hidden states together.
            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)

            if not concat_captions:
                last_embedding = torch.stack([last_hidden_state[i, last_embedding_idx[i], :] for i in range(batch_size)], axis=0)  # (N, D)
                last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0)  # (N, D)
            else:
                # Concatenate two captioning examples together.
                all_last_embedding = []
                all_last_output_logit = []
                for i in range(batch_size // 2):
                    first_last_embedding_idx, second_last_embedding_idx = all_last_embedding_idx[i]
                    first_last_embedding = last_hidden_state[i, first_last_embedding_idx, :]  # (N, D)
                    first_last_output_logit = output.logits[i, first_last_embedding_idx - 1, :]  # (N, D)
                    second_last_embedding = last_hidden_state[i, second_last_embedding_idx, :]  # (N, D)
                    second_last_output_logit = output.logits[i, second_last_embedding_idx - 1, :]  # (N, D)
                    all_last_embedding.append(first_last_embedding)
                    all_last_embedding.append(second_last_embedding)
                    all_last_output_logit.append(first_last_output_logit)
                    all_last_output_logit.append(second_last_output_logit)

                last_embedding = torch.stack(all_last_embedding)
                last_output_logit = torch.stack(all_last_output_logit)

            # Compute retrieval loss.
            assert visual_embs.shape[1] == 1, visual_embs.shape
            visual_embs = visual_embs[:, 0, :]
            visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
            last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)

            # Cosine similarity as logits.
            logit_scale = self.logit_scale.exp()
            visual_embs = logit_scale * visual_embs
        elif mode == 'captioning':
            pass
        else:
            raise NotImplementedError

        return output, full_labels, last_embedding, last_output_logit, visual_embs

    def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
                 temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
                 ret_scale_factor: float = 1.0, filter_value: float = -float('Inf')):
        """Runs greedy decoding and returns generated captions.

        Args:
          embeddings: Input condition that the model uses for autoregressive generation.
          max_len: Maximum number of tokens to generate.
          temperature: Used to modulate logit distribution.
          top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
          min_word_tokens: Minimum number of words to generate before allowing a [RET] output.
          ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs.
          filter_value: Value to assign to tokens that should never be generated.
        Outputs:
          out: (N, T) int32 sequence of output tokens.
          output_embeddings: (N, T, 256) sequence of text output embeddings.
        """
        self.lm.eval()

        with torch.no_grad():  # No tracking history.
            batch_size, s, _ = embeddings.shape
            # Init output with image tokens.
            out = None
            past_key_values = None
            output_embeddings = []
            output_logits = []

            for i in range(max_len):
                if 'opt' in self.opt_version:
                    output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
                else:
                    if i == 0:
                        output = self.lm(inputs_embeds=embeddings, use_cache=True, past_key_values=None, output_hidden_states=True)
                    else:
                        output = self.lm(input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values, output_hidden_states=True)

                # Collect and sum the hidden states.
                hidden_states = []
                if self.args.shared_emb_dim is not None:
                    for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs):
                        hidden_states.append(fc_layer(output.hidden_states[idx]))  # (N, seq_len, 2048)
                else:
                    for idx in self.args.text_emb_layers:
                        hidden_states.append(output.hidden_states[idx])
                # Add hidden states together.
                last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)  # (N, T, 256)
                last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True)
                output_embeddings.append(last_embedding)

                logits = output.logits[:, -1, :]  # (N, vocab_size)
                if top_p == 1.0:
                    logits = logits.cpu()
                output_logits.append(logits)

                if self.retrieval_token_idx != -1 and self.retrieval_token_idx is not None:
                    if i < min_word_tokens:
                        # Eliminate probability of generating [RET] if this is earlier than min_word_tokens.
                        logits[:, self.retrieval_token_idx] = filter_value
                    else:
                        # Multiply by scaling factor.
                        logits[:, self.retrieval_token_idx] = logits[:, self.retrieval_token_idx] * ret_scale_factor

                past_key_values = output.past_key_values

                if temperature == 0.0:
                    if top_p != 1.0:
                        raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
                    next_token = torch.argmax(logits, keepdim=True, dim=-1)  # (N, 1)
                else:
                    logits = logits / temperature

                    # Apply top-p filtering.
                    if top_p < 1.0:
                        assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)  # (N, D) and (N, D)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # (N, D)

                        # Remove tokens with cumulative probability above the threshold.
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the indices to the right to keep also the first token above the threshold.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0

                        for j in range(sorted_indices.shape[0]):
                            indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
                            logits[j, indices_to_remove] = filter_value

                    token_weights = logits.exp()  # (N, vocab_size)
                    next_token = torch.multinomial(token_weights, 1)  # (N, 1)

                next_token = next_token.long().to(embeddings.device)
                if out is not None:
                    out = torch.cat([out, next_token], dim=-1)
                else:
                    out = next_token

                if 'opt' in self.opt_version:
                    next_embedding = self.input_embeddings(next_token)
                    embeddings = torch.cat([embeddings, next_embedding], dim=1)
                elif (self.tokenizer.eos_token_id and (next_token == self.tokenizer.eos_token_id).all()):
                    # End of generation.
                    break

        return out, output_embeddings, output_logits


class Fromage(nn.Module):
    def __init__(self, tokenizer, model_args: Optional[FrozenArgs] = None,
                 path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.tensor] = None):
        super().__init__()
        self.model = FromageModel(tokenizer, model_args)
        self.path_array = path_array
        self.emb_matrix = emb_matrix

    def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
                 generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
                 ret_scale_factor: float = 1.0, min_word_tokens: int = 0,
                 mode: str = 'captioning', concat_captions: bool = False,
                 input_prefix: Optional[str] = None, inference: bool = False) -> Tensor:
        if generate:
            return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
                                       min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor)
        else:
            output = self.model(
                pixel_values=images,
                labels=tgt_tokens,
                caption_len=caption_len,
                mode=mode,
                concat_captions=concat_captions,
                input_prefix=input_prefix,
                inference=inference)
            return output

    def generate_for_images_and_texts(
            self, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0, temperature: float = 0.0,
            max_num_rets: int = 1):
        """
        Encode prompts into embeddings.

        Args:
          prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
          num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs.
          ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs.
          top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
          temperature: Used to modulate logit distribution.
          max_num_rets: Maximum number of images to return in one generation pass.
        Returns:
          return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs.
        """
        input_embs = []
        input_ids = []
        add_bos = True

        for i, p in enumerate(prompts):
            if type(p) == Image.Image:
                # Encode as image.
                pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
                pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
                pixel_values = pixel_values[None, ...]

                visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning')  # (1, n_visual_tokens, D)
                input_embs.append(visual_embs)
            elif type(p) == str:
                text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
                if not add_bos:
                    # Remove <bos> tag.
                    text_ids = text_ids[:, 1:]
                else:
                    # Only add <bos> once.
                    add_bos = False

                text_embs = self.model.input_embeddings(text_ids)  # (1, T, D)
                input_embs.append(text_embs)
                input_ids.append(text_ids)
            else:
                raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
        input_embs = torch.cat(input_embs, dim=1)
        input_ids = torch.cat(input_ids, dim=1)

        if num_words == 0:
            generated_ids = input_ids
            outputs = self.model.lm(inputs_embeds=input_embs, use_cache=False, output_hidden_states=True)
            # Map outputs to embeddings, so we can retrieve embeddings from the [RET] tokens.
            out = []
            for x, fc in zip(self.model.args.text_emb_layers, self.model.text_hidden_fcs):
                out.append(fc(outputs.hidden_states[x]))
            embeddings = torch.stack(out, dim=-1).sum(dim=-1)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (N, T, 256)
        elif num_words > 0:
            generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words,
                temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
            embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]

            # Truncate to newline.
            newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
            trunc_idx = 0
            for j in range(generated_ids.shape[1]):
                if generated_ids[0, j] == newline_token_id:
                    trunc_idx = j
                    break
            if trunc_idx > 0:
                generated_ids = generated_ids[:, :trunc_idx]
                embeddings = embeddings[:, :trunc_idx]
        else:
            raise ValueError

        # Save outputs as an interleaved list.
        return_outputs = []
        # Find up to max_num_rets [RET] tokens, and their corresponding scores.
        all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx) if x][:max_num_rets]
        seen_image_idx = []  # Avoid showing the same image multiple times.

        last_ret_idx = 0
        if len(all_ret_idx) == 0:
            # No [RET] tokens.
            caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return_outputs.append(utils.truncate_caption(caption))
        else:
            for ret_idx in all_ret_idx:
                ret_emb = embeddings[:, ret_idx, :]
                scores = self.emb_matrix @ ret_emb.T

                # Downweight seen images.
                for seen_idx in seen_image_idx:
                    scores[seen_idx, :] -= 1000

                # Get the top 3 images for each image.
                _, top_image_idx = scores.squeeze().topk(3)
                image_outputs = []
                for img_idx in top_image_idx:
                    # Find the first image that does not error out.
                    try:
                        seen_image_idx.append(img_idx)
                        img = utils.get_image_from_url(self.path_array[img_idx])
                        image_outputs.append(img)
                        if len(image_outputs) == max_num_rets:
                            break
                    except UnidentifiedImageError:
                        pass

                caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
                last_ret_idx = ret_idx + 1
                return_outputs.append(utils.truncate_caption(caption) + ' [RET]')
                return_outputs.append(image_outputs)

        return return_outputs


def load_fromage(model_dir: str) -> Fromage:
    model_args_path = os.path.join(model_dir, 'model_args.json')
    model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
    embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]

    if not os.path.exists(model_args_path):
        raise ValueError(f'model_args.json does not exist in {model_dir}.')
    if not os.path.exists(model_ckpt_path):
        raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')
    if len(embs_paths) == 0:
        raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {model_dir}.')

    # Load embeddings.
    # Construct embedding matrix for nearest neighbor lookup.
    path_array = []
    emb_matrix = []

    # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
    for p in embs_paths:
        with open(p, 'rb') as wf:
            train_embs_data = pkl.load(wf)
            path_array.extend(train_embs_data['paths'])
            emb_matrix.append(train_embs_data['embeddings'])
    emb_matrix = np.concatenate(emb_matrix, axis=0)

    # Number of paths should be equal to number of embeddings.
    assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape[0])

    with open(model_args_path, 'r') as f:
        model_kwargs = json.load(f)

    # Initialize tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
    tokenizer.pad_token = tokenizer.eos_token
    # Add special tokens to the model to enable [RET].
    tokenizer.add_special_tokens({"cls_token": "<|image|>"})
    tokenizer.add_tokens('[RET]')
    ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
    assert len(ret_token_idx) == 1, ret_token_idx
    model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
    args = namedtuple('args', model_kwargs)(**model_kwargs)

    # Initialize model for inference.
    model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
    model = model.eval()
    model = model.bfloat16()
    model = model.cuda()

    # Load pretrained linear mappings and [RET] embeddings.
    checkpoint = torch.load(model_ckpt_path)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    with torch.no_grad():
        model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())

    logit_scale = model.model.logit_scale.exp()
    emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
    emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
    emb_matrix = logit_scale * emb_matrix
    model.emb_matrix = emb_matrix

    return model
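Note: the densest part of `FromageModel.generate` is the nucleus (top-p) filtering applied before sampling. The standalone sketch below reproduces just that step on a dummy logits tensor so its effect can be checked in isolation; the function name and tensor sizes are illustrative, not part of the model code.

# Standalone sketch of the top-p filtering logic used in FromageModel.generate.
# Shapes, values, and the helper name `top_p_filter` are illustrative only.
import torch
import torch.nn.functional as F

def top_p_filter(logits: torch.Tensor, top_p: float, filter_value: float = -float('Inf')) -> torch.Tensor:
    # Sort logits per row, accumulate probabilities, and mask the tail whose
    # cumulative probability exceeds top_p (keeping the first token over the threshold).
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    for j in range(sorted_indices.shape[0]):
        indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
        logits[j, indices_to_remove] = filter_value
    return logits

logits = torch.randn(2, 10)                # (N, vocab_size) dummy logits
filtered = top_p_filter(logits.clone(), top_p=0.9)
probs = F.softmax(filtered, dim=-1)        # masked tokens get zero probability
next_token = torch.multinomial(probs, 1)   # equivalent to exp() + multinomial in generate()
print(next_token.shape)                    # torch.Size([2, 1])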
fromage/utils.py
ADDED
@@ -0,0 +1,250 @@
from enum import Enum
import subprocess
import sys
import shutil
import torch
import torch.distributed as dist
from torchvision.transforms import functional as F
from torchvision import transforms as T
from transformers import AutoFeatureExtractor
from PIL import Image, ImageDraw, ImageFont, ImageOps
import requests
from io import BytesIO

import random


def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']):
    """Logs git status to stdout."""
    subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file)
    subprocess.call('echo', shell=True, stdout=out_file)
    exclude_string = ''
    subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file)


def get_image_from_url(url: str):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    img = img.resize((224, 224))
    img = img.convert('RGB')
    return img


def truncate_caption(caption: str) -> str:
    """Truncate captions at periods and newlines."""
    trunc_index = caption.find('\n') + 1
    if trunc_index <= 0:
        trunc_index = caption.find('.') + 1
    caption = caption[:trunc_index]
    return caption


def pad_to_size(x, size=256):
    delta_w = size - x.size[0]
    delta_h = size - x.size[1]
    padding = (
        delta_w // 2,
        delta_h // 2,
        delta_w - (delta_w // 2),
        delta_h - (delta_h // 2),
    )
    new_im = ImageOps.expand(x, padding)
    return new_im


class RandCropResize(object):
    """
    Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, img):
        img = pad_to_size(img, self.target_size)
        d_min = min(img.size)
        img = T.RandomCrop(size=d_min)(img)
        t_min = min(d_min, round(9 / 8 * self.target_size))
        t_max = min(d_min, round(12 / 8 * self.target_size))
        t = random.randint(t_min, t_max + 1)
        img = T.Resize(t)(img)
        if min(img.size) < 256:
            img = T.Resize(256)(img)
        return T.RandomCrop(size=self.target_size)(img)


class SquarePad(object):
    """Pads image to square.
    From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """
    def __call__(self, image):
        max_wh = max(image.size)
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])]
        padding = (p_left, p_top, p_right, p_bottom)
        return F.pad(image, padding, 0, 'constant')


def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
    """Creates a (3, nrows * 14, width) image of text.

    Returns:
      cap_img: (3, 14 * nrows, width) image of wrapped text.
    """
    height = 12
    padding = 5
    effective_width = width - 2 * padding
    # Create a black image to draw text on.
    cap_img = Image.new('RGB', (effective_width * nrows, height), color=(0, 0, 0))
    draw = ImageDraw.Draw(cap_img)
    draw.text((0, 0), text, color, font=font or ImageFont.load_default())
    cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32)  # (3, height, W * nrows)
    cap_img = torch.split(cap_img, effective_width, dim=-1)  # List of nrow elements of shape (3, height, W)
    cap_img = torch.cat(cap_img, dim=1)  # (3, height * nrows, W)
    # Add zero padding.
    cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding])
    return cap_img


def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
    print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    return feature_extractor


def get_pixel_values_for_model(feature_extractor, img):
    pixel_values = feature_extractor(
        img.convert('RGB'),
        return_tensors="pt").pixel_values[0, ...]  # (3, H, W)
    return pixel_values


def save_checkpoint(state, is_best, filename='checkpoint'):
    torch.save(state, filename + '.pth.tar')
    if is_best:
        shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar')


def accuracy(output, target, padding, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        if output.shape[-1] < maxk:
            print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")

        maxk = min(maxk, output.shape[-1])
        batch_size = target.size(0)

        # Take topk along the last dimension.
        _, pred = output.topk(maxk, -1, True, True)  # (N, T, topk)

        mask = (target != padding).type(target.dtype)
        target_expand = target[..., None].expand_as(pred)
        correct = pred.eq(target_expand)
        correct = correct * mask[..., None].expand_as(correct)

        res = []
        for k in topk:
            correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / mask.sum()))
        return res


def get_params_count(model, max_name_len: int = 60):
    params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()]
    total_trainable_params = sum([x[1] for x in params if x[-1]])
    total_nontrainable_params = sum([x[1] for x in params if not x[-1]])
    return params, total_trainable_params, total_nontrainable_params


def get_params_count_str(model, max_name_len: int = 60):
    padding = 70  # Hardcoded depending on desired amount of padding and separators.
    params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len)
    param_counts_text = ''
    param_counts_text += '=' * (max_name_len + padding) + '\n'
    param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n'
    param_counts_text += '-' * (max_name_len + padding) + '\n'
    for name, param_count, shape, trainable in params:
        param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
    param_counts_text += '-' * (max_name_len + padding) + '\n'
    param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n'
    param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n'
    param_counts_text += '=' * (max_name_len + padding) + '\n'
    return param_counts_text


class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)
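Note: `AverageMeter` and `ProgressMeter` are only exercised by the training scripts, which are not part of this Space. A small hypothetical usage sketch over dummy batches, just to show how they compose:

# Hypothetical usage sketch for AverageMeter / ProgressMeter; the loop and
# loss values below are dummy placeholders, not part of the Space.
from fromage.utils import AverageMeter, ProgressMeter, Summary

num_batches = 10
losses = AverageMeter('Loss', ':.4e', Summary.AVERAGE)
progress = ProgressMeter(num_batches, [losses], prefix='Epoch: [0]')

for i in range(num_batches):
    batch_size = 4
    loss = 1.0 / (i + 1)          # placeholder for a real loss value
    losses.update(loss, batch_size)
    if i % 5 == 0:
        progress.display(i)       # e.g. "Epoch: [0][ 0/10]\tLoss 1.0000e+00 (1.0000e+00)"

progress.display_summary()        # prints the running average, e.g. " * Loss 0.293"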