from typing import Callable, List, Optional, Tuple, Union |
from collections import namedtuple |
import json |
import glob |
import math |
import numpy as np |
import os |
import torch |
from torch import Tensor |
import torch.nn as nn |
import torch.nn.functional as F |
from einops import rearrange |
from functools import partial |
import pickle as pkl |
from PIL import Image, UnidentifiedImageError |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM |
from transformers import OPTForCausalLM, GPT2Tokenizer |
from transformers import CLIPVisionModel, CLIPVisionConfig |
from fromage import utils |
class FrozenArgs:
    """Default hyper-parameters consumed by ``FromageModel``.

    The values are plain class attributes so that both ``FrozenArgs.x`` and
    ``FrozenArgs().x`` work. Any object exposing the same attribute names
    (e.g. a namedtuple built from a JSON file) can be used in its place.
    """

    # Whether the language model / visual model parameters are frozen.
    freeze_lm: bool = True
    freeze_vm: bool = True
    # HuggingFace identifiers of the language model and the visual encoder.
    opt_version: str = 'facebook/opt-6.7b'
    visual_encoder: str = 'openai/clip-vit-large-patch14'
    # Number of LM input embeddings produced per image in captioning mode.
    n_visual_tokens: int = 1
    # Dropout probability applied to the visual embeddings.
    image_embed_dropout_prob: float = 0.0
    # Dropout probability applied after each text projection head. It is read
    # by FromageModel.__init__, so it needs a default here.
    text_embed_dropout_prob: float = 0.0
    task: str = 'captioning'
    # Size of the shared image/text retrieval space; None disables projection.
    shared_emb_dim: Optional[int] = 256
    # Indices of the LM hidden states that are combined into the text embedding.
    text_emb_layers: List[int] = [-1]
    # Token id of the retrieval token [RET].
    retrieval_token_idx: int = 0

    def __init__(self) -> None:
        # Give every instance its own list so that mutating one instance's
        # ``text_emb_layers`` cannot leak into the class default or into
        # other instances.
        self.text_emb_layers = list(type(self).text_emb_layers)
class FromageModel(nn.Module): |
def __init__(self, tokenizer, args: Optional["FrozenArgs"] = None):
    """Build the language model, the visual encoder and the mapping layers.

    Args:
        tokenizer: Tokenizer paired with the language model. Its length is used
            to resize the LM embedding table and its ``cls_token_id`` is stored
            as the image token id.
        args: Model hyper-parameters (any object with the ``FrozenArgs``
            attributes). ``None`` selects ``FrozenArgs()``.

    Raises:
        ValueError: If ``args.text_emb_layers`` contains duplicates, requests a
            layer index beyond the LM depth, or (with ``shared_emb_dim=None``)
            mixes the final layer with other layers of a different width.
        NotImplementedError: If the LM is not a ``facebook/opt`` checkpoint or
            the visual encoder is not a CLIP checkpoint.
    """
    super().__init__()
    if args is None:
        args = FrozenArgs()
    # A list never compares equal to a set, so uniqueness has to be checked
    # through the lengths.
    if len(args.text_emb_layers) != len(set(args.text_emb_layers)):
        raise ValueError('text_emb_layers not unique')

    self.tokenizer = tokenizer
    self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
    self.image_token = self.tokenizer.cls_token_id
    self.args = args

    opt_version = args.opt_version
    visual_encoder = args.visual_encoder
    n_visual_tokens = args.n_visual_tokens
    print(f"Using {opt_version} for the language model.")
    print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")

    # Learnable temperature (stored in log space) for the retrieval similarity.
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    # --- Language model ---------------------------------------------------
    if 'facebook/opt' in opt_version:
        self.lm = OPTForCausalLM.from_pretrained(opt_version)
    else:
        raise NotImplementedError
    self.opt_version = opt_version

    if self.args.freeze_lm:
        self.lm.eval()
        print("Freezing the LM.")
        for param in self.lm.parameters():
            param.requires_grad = False
    else:
        self.lm.train()

    self.retrieval_token_idx = args.retrieval_token_idx
    print(f'Initializing embedding for the retrieval token [RET] (id = {self.retrieval_token_idx}).')
    # The tokenizer may contain tokens the pretrained LM does not know about.
    self.lm.resize_token_embeddings(len(tokenizer))
    self.input_embeddings = self.lm.get_input_embeddings()

    # --- Visual model -----------------------------------------------------
    print("Restoring pretrained weights for the visual model.")
    if 'clip' in visual_encoder:
        self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
        hidden_size = self.visual_model.config.hidden_size
    else:
        # Only CLIP encoders are supported; fail before loading anything else.
        raise NotImplementedError

    if self.args.freeze_vm:
        print("Freezing the VM.")
        self.visual_model.eval()
        for param in self.visual_model.parameters():
            param.requires_grad = False
    else:
        self.visual_model.train()
    self.visual_model_name = visual_encoder

    # --- Mapping layers ---------------------------------------------------
    # One LM-sized embedding per visual token.
    embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
    self.text_hidden_fcs = nn.ModuleList([])

    lm_config = self.lm.config
    num_layers = lm_config.num_hidden_layers
    not_bert = 'bert' not in opt_version
    layers = self.args.text_emb_layers

    if self.args.shared_emb_dim is None:
        # No projection: the retrieval space is whatever width the selected
        # hidden states already have.
        uses_final_layer = (-1 in layers) or (num_layers in layers)
        if len(layers) == 1:
            if uses_final_layer and not_bert:
                out_dim = lm_config.word_embed_proj_dim
            else:
                out_dim = lm_config.hidden_size
        elif uses_final_layer and (lm_config.word_embed_proj_dim != lm_config.hidden_size):
            # Parenthesised on purpose: the width mismatch must hold for the
            # combination to be rejected.
            raise ValueError('No projection dim specified but model uses last output layer and an intermediate one (which have different dims).')
        else:
            out_dim = lm_config.hidden_size
    else:
        out_dim = self.args.shared_emb_dim
        # Projection heads are only needed when a shared embedding size is
        # requested; forward() and generate() use raw hidden states otherwise.
        text_dropout_prob = getattr(self.args, 'text_embed_dropout_prob', 0.0)
        for layer_idx in layers:
            if (layer_idx == -1 or layer_idx == num_layers) and not_bert:
                in_dim = lm_config.word_embed_proj_dim
            elif layer_idx < num_layers:
                in_dim = lm_config.hidden_size
            else:
                raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {num_layers} layers.')
            self.text_hidden_fcs.append(
                nn.Sequential(nn.Linear(in_dim, out_dim), nn.Dropout(text_dropout_prob)))

    # Image -> LM input embeddings (captioning) and image -> retrieval space.
    self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
    self.visual_fc = nn.Linear(hidden_size, out_dim)
    self.image_dropout = nn.Dropout(self.args.image_embed_dropout_prob)
def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
    """Encode images and map them into the space required by ``mode``.

    Args:
        pixel_values: Preprocessed images, passed unchanged to the visual model.
        mode: ``'captioning'`` maps each image to ``n_visual_tokens`` LM input
            embeddings; ``'retrieval'`` maps it to a single retrieval embedding.

    Returns:
        Tensor of shape ``(batch, n_visual_tokens, -1)`` in captioning mode or
        ``(batch, 1, -1)`` in retrieval mode, after image dropout.

    Raises:
        ValueError: If ``mode`` is not one of the two supported values.
        NotImplementedError: If the visual model is not a CLIP model.
    """
    if mode not in ['captioning', 'retrieval']:
        raise ValueError(f'mode should be one of ["captioning", "retrieval"], got {mode} instead.')
    if 'clip' not in self.visual_model_name:
        raise NotImplementedError

    outputs = self.visual_model(pixel_values)
    encoder_outputs = outputs.pooler_output

    if mode == 'captioning':
        visual_embs = self.visual_embeddings(encoder_outputs)
        # Split the flat projection into one embedding per visual token.
        visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
    else:
        # mode == 'retrieval' (validated above).
        visual_embs = self.visual_fc(encoder_outputs)
        visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))

    return self.image_dropout(visual_embs)
def train(self, mode=True): |
super(FromageModel, self).train(mode=mode) |
if self.args.freeze_lm: |
self.lm.eval() |
if self.args.freeze_vm: |
self.visual_model.eval() |
def forward(
    self,
    pixel_values: torch.FloatTensor,
    labels: torch.LongTensor,
    caption_len: torch.LongTensor,
    mode: str = 'captioning',
    concat_captions: bool = False,
    input_prefix: Optional[str] = None,
    inference: bool = False,
):
    """Run the language model on image-conditioned (or text-only) inputs.

    Args:
        pixel_values: Preprocessed images, forwarded to ``get_visual_embs``.
        labels: ``(batch, seq)`` token ids of the captions.
        caption_len: Per-example caption lengths; ``caption_len - 1`` is used
            as the position of the last caption token.
        mode: ``'captioning'`` prepends the visual embeddings to the caption;
            ``'retrieval'`` feeds only text to the LM and returns embeddings for
            contrastive matching against the image.
        concat_captions: Captioning only. Packs consecutive example pairs
            ``(2i, 2i + 1)`` into one sequence; requires an even batch size.
        input_prefix: Optional text placed before the caption (after the image
            in captioning mode).
        inference: Unused by this method.

    Returns:
        Tuple ``(output, full_labels, last_embedding, last_output_logit,
        visual_embs)``. ``last_embedding`` and ``last_output_logit`` are
        ``None`` in captioning mode. In retrieval mode ``visual_embs`` is the
        L2-normalised ``(batch, dim)`` image embedding scaled by
        ``exp(logit_scale)`` and ``last_embedding`` is L2-normalised.

    Raises:
        NotImplementedError: If ``concat_captions`` is requested in retrieval
            mode (sequences are never packed there, so it cannot be honoured).
        ValueError: If ``mode`` is unknown (raised by ``get_visual_embs``).
    """
    if mode == 'retrieval' and concat_captions:
        raise NotImplementedError('concat_captions is only supported in captioning mode.')

    def mask_from_first_stop(token_labels, stop_ids):
        """Set labels to -100 from the first stop token on; return the cut index per row."""
        cut_points = []
        for row in token_labels:
            cut = row.shape[0]  # No stop token: nothing is masked.
            for pos, token in enumerate(row.tolist()):
                if token in stop_ids:
                    cut = pos
                    row[pos:] = -100  # `row` is a view, so this edits token_labels.
                    break
            cut_points.append(cut)
        return cut_points

    visual_embs = self.get_visual_embs(pixel_values, mode)
    batch_size, vis_seq_len, _ = visual_embs.shape
    if labels is not None:
        assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)

    input_embs = self.input_embeddings(labels)
    last_embedding_idx = caption_len - 1  # New tensor; caption_len is not modified below.

    if input_prefix is not None:
        prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
        prompt_ids = prompt_ids.to(visual_embs.device)
        prompt_embs = self.input_embeddings(prompt_ids)
        prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
        assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
        assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
        assert len(prompt_embs.shape) == 3, prompt_embs.shape

    if mode == 'captioning':
        # Conditioning prefix: image tokens, optionally followed by the prompt.
        if input_prefix is None:
            prefix_embs = visual_embs
        else:
            prefix_embs = torch.cat([visual_embs, prompt_embs], dim=1)
        input_embs = torch.cat([prefix_embs, input_embs], dim=1)
        last_embedding_idx += prefix_embs.shape[1]

        # The prefix positions carry the ignore label (-100).
        prefix_labels = torch.full(
            prefix_embs.shape[:2], -100, dtype=torch.int64, device=visual_embs.device)
        full_labels = torch.cat([prefix_labels, labels], dim=1)

        # Ignore everything from the first padding or [RET] token onwards.
        pad_idx = mask_from_first_stop(
            full_labels, [self.tokenizer.pad_token_id, self.retrieval_token_idx])
        assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

        bs, seq_len, embs_dim = input_embs.shape
        if concat_captions:
            assert len(input_embs.shape) == 3, input_embs
            assert len(full_labels.shape) == 2, full_labels
            assert batch_size % 2 == 0
            all_concat_input_embs = []
            all_concat_labels = []

            # Pack each pair as [first content, second content, first pad, second pad].
            for i in range(batch_size // 2):
                first_idx = i * 2
                second_idx = first_idx + 1
                first_cut = pad_idx[first_idx]
                second_cut = pad_idx[second_idx]

                first_emb = input_embs[first_idx, :first_cut, :]
                first_labels = full_labels[first_idx, :first_cut]
                first_padding = input_embs[first_idx, first_cut:, :]
                first_labels_padding = full_labels[first_idx, first_cut:]

                second_emb = input_embs[second_idx, :second_cut, :]
                second_labels = full_labels[second_idx, :second_cut]
                second_padding = input_embs[second_idx, second_cut:, :]
                second_labels_padding = full_labels[second_idx, second_cut:]

                assert torch.all(first_labels_padding == -100), first_labels_padding
                assert torch.all(second_labels_padding == -100), second_labels_padding

                all_concat_input_embs.append(
                    torch.cat([first_emb, second_emb, first_padding, second_padding], dim=0))
                all_concat_labels.append(
                    torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], dim=0))

            input_embs = torch.stack(all_concat_input_embs, dim=0)
            full_labels = torch.stack(all_concat_labels, dim=0)
            assert input_embs.shape == (bs // 2, seq_len * 2, embs_dim), input_embs.shape
            assert full_labels.shape == (bs // 2, seq_len * 2), full_labels.shape

        output = self.lm(inputs_embeds=input_embs,
                         labels=full_labels,
                         output_hidden_states=True)
    elif mode == 'retrieval':
        # Clone so that masking below does not modify the caller's labels.
        full_labels = torch.clone(labels)
        if input_prefix is not None:
            print(f'Adding prefix "{input_prefix}" to retrieval.')
            prefix_embs = prompt_embs
            input_embs = torch.cat([prefix_embs, input_embs], dim=1)
            last_embedding_idx += prefix_embs.shape[1]
            prefix_labels = torch.full(
                prefix_embs.shape[:2], -100, dtype=torch.int64, device=labels.device)
            full_labels = torch.cat([prefix_labels, full_labels], dim=1)

        # Only padding is masked here; the [RET] token stays a training target.
        pad_idx = mask_from_first_stop(full_labels, [self.tokenizer.pad_token_id])
        assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

        output = self.lm(inputs_embeds=input_embs,
                         labels=full_labels,
                         output_hidden_states=True)
    else:
        raise NotImplementedError

    last_embedding = None
    last_output_logit = None

    if mode == 'retrieval':
        # Combine the selected LM layers (projected when a shared space is used).
        if self.args.shared_emb_dim is not None:
            hidden_states = [
                fc_layer(output.hidden_states[idx])
                for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs)
            ]
        else:
            hidden_states = [output.hidden_states[idx] for idx in self.args.text_emb_layers]
        last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)

        # Text embedding = hidden state at the last caption token; the logit is
        # taken one position earlier (the prediction for that token).
        last_embedding = torch.stack(
            [last_hidden_state[i, last_embedding_idx[i], :] for i in range(batch_size)], dim=0)
        last_output_logit = torch.stack(
            [output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], dim=0)

        # Cosine similarity: normalise both sides, then apply the temperature.
        assert visual_embs.shape[1] == 1, visual_embs.shape
        visual_embs = visual_embs[:, 0, :]
        visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
        last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)
        logit_scale = self.logit_scale.exp()
        visual_embs = logit_scale * visual_embs

    return output, full_labels, last_embedding, last_output_logit, visual_embs
def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
             temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
             ret_scale_factor: float = 1.0, filter_value: float = -float('Inf')):
    """Autoregressively decode tokens from input embeddings.

    Greedy decoding is used when ``temperature == 0``; otherwise tokens are
    sampled, optionally restricted to the top-p nucleus.

    Args:
        embeddings: ``(N, S, D)`` input embeddings that condition generation.
        max_len: Number of decoding steps to run.
        temperature: Divides the logits before sampling; 0 selects greedy decoding.
        top_p: If < 1, only the smallest set of tokens whose cumulative
            probability reaches ``top_p`` is kept for sampling.
        min_word_tokens: Number of initial steps during which [RET] is forbidden.
        ret_scale_factor: Factor the [RET] logit is multiplied by afterwards.
        filter_value: Logit assigned to tokens that must not be generated.

    Returns:
        Tuple ``(out, output_embeddings, output_logits)``:
        ``out`` is the ``(N, T)`` int64 tensor of generated token ids (``None``
        if ``max_len`` is 0); ``output_embeddings`` is a list with one
        L2-normalised hidden-state tensor per step, covering every input
        position of that step; ``output_logits`` is a list with the
        ``(N, vocab)`` last-position logits of each step (after the [RET]
        adjustment, which is applied in place).

    Raises:
        ValueError: If ``embeddings`` is not 3-dimensional, or if ``top_p`` is
            set together with greedy decoding.
    """
    self.lm.eval()

    with torch.no_grad():
        if embeddings.dim() != 3:
            raise ValueError(f'embeddings should have shape (N, S, D), got {tuple(embeddings.shape)} instead.')

        out = None
        past_key_values = None
        output_embeddings = []
        output_logits = []

        for i in range(max_len):
            if 'opt' in self.opt_version:
                # The whole sequence is re-encoded each step (no KV cache).
                output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
            else:
                if i == 0:
                    output = self.lm(inputs_embeds=embeddings, use_cache=True, past_key_values=None, output_hidden_states=True)
                else:
                    output = self.lm(input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values, output_hidden_states=True)

            # Text embedding: sum of the selected layers (projected when a
            # shared embedding space is configured), L2-normalised.
            if self.args.shared_emb_dim is not None:
                hidden_states = [
                    fc_layer(output.hidden_states[idx])
                    for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs)
                ]
            else:
                hidden_states = [output.hidden_states[idx] for idx in self.args.text_emb_layers]
            last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
            last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True)
            output_embeddings.append(last_embedding)

            logits = output.logits[:, -1, :]
            if top_p == 1.0:
                logits = logits.cpu()
            output_logits.append(logits)

            if self.retrieval_token_idx != -1 and self.retrieval_token_idx is not None:
                if i < min_word_tokens:
                    # Too early for a retrieval token: forbid it.
                    logits[:, self.retrieval_token_idx] = filter_value
                else:
                    logits[:, self.retrieval_token_idx] = logits[:, self.retrieval_token_idx] * ret_scale_factor

            past_key_values = output.past_key_values

            if temperature == 0.0:
                if top_p != 1.0:
                    raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
                next_token = torch.argmax(logits, keepdim=True, dim=-1)
            else:
                logits = logits / temperature

                if top_p < 1.0:
                    assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Drop tokens once the cumulative mass exceeds top_p; the
                    # shift keeps the first token that crosses the threshold.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    for j in range(sorted_indices.shape[0]):
                        indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
                        logits[j, indices_to_remove] = filter_value

                # softmax instead of a raw exp(): the sampling distribution is
                # the same, but large logits can no longer overflow to inf
                # (which torch.multinomial rejects).
                token_weights = F.softmax(logits, dim=-1, dtype=torch.float32)
                next_token = torch.multinomial(token_weights, 1)

            next_token = next_token.long().to(embeddings.device)

            if out is not None:
                out = torch.cat([out, next_token], dim=-1)
            else:
                out = next_token

            if 'opt' in self.opt_version:
                # Feed the new token back in as an embedding for the next step.
                # NOTE(review): on this path decoding always runs for max_len
                # steps; EOS is only checked for non-OPT models below.
                next_embedding = self.input_embeddings(next_token)
                embeddings = torch.cat([embeddings, next_embedding], dim=1)
            elif (self.tokenizer.eos_token_id is not None
                  and (next_token == self.tokenizer.eos_token_id).all()):
                # Every sequence in the batch produced EOS at this step.
                break

    return out, output_embeddings, output_logits
class Fromage(nn.Module): |
def __init__(self, tokenizer, model_args: Optional["FrozenArgs"] = None,
             path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.Tensor] = None):
    """Wrap a ``FromageModel`` together with an image retrieval index.

    Args:
        tokenizer: Tokenizer forwarded to ``FromageModel``.
        model_args: Model hyper-parameters; ``None`` selects ``FrozenArgs()``.
        path_array: Image locations, one per row of ``emb_matrix``. Only needed
            when retrieved images are requested.
        emb_matrix: Embeddings of the retrievable images, scored against the
            [RET] token embedding by matrix product.
    """
    super().__init__()
    # FromageModel dereferences its args immediately, so the documented None
    # default has to be resolved here.
    if model_args is None:
        model_args = FrozenArgs()
    self.model = FromageModel(tokenizer, model_args)
    self.path_array = path_array
    self.emb_matrix = emb_matrix
def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None, |
generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0, |
ret_scale_factor: float = 1.0, min_word_tokens: int = 0, |
mode: str = 'captioning', concat_captions: bool = False, |
input_prefix: Optional[str] = None, inference: bool = False) -> Tensor: |
if generate: |
return self.model.generate(images, num_words, temperature=temperature, top_p=top_p, |
min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor) |
else: |
output = self.model( |
pixel_values = images, |
labels = tgt_tokens, |
caption_len = caption_len, |
mode = mode, |
concat_captions = concat_captions, |
input_prefix = input_prefix, |
inference = inference) |
return output |
def generate_for_images_and_texts(
        self, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0, temperature: float = 0.0,
        max_num_rets: int = 1):
    """Run the model on interleaved image/text prompts and decode its output.

    Args:
        prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
        num_words: Maximum number of tokens to generate. If 0, the model only
            runs a forward pass over the prompts and the prompt text itself is
            scanned for [RET] tokens.
        ret_scale_factor: Factor the [RET] logit is multiplied by during generation.
        top_p: If set to < 1, the smallest set of tokens with highest probabilities
            that add up to top_p or higher are kept for generation.
        temperature: Used to modulate the logit distribution (0 = greedy).
        max_num_rets: Maximum number of [RET] tokens handled, and maximum
            number of images returned for each of them.

    Returns:
        List of ``str`` and ``List[image]`` items. Without a [RET] token it is a
        single caption; otherwise each handled [RET] contributes the text
        preceding it (suffixed with ``' [RET]'``) followed by its retrieved images.

    Raises:
        ValueError: If ``num_words`` is negative, a prompt is neither an image
            nor a string, no prompt is given, ``num_words == 0`` without any
            text prompt, or a [RET] token must be resolved without a retrieval index.
    """
    if num_words < 0:
        raise ValueError(f'num_words should be >= 0, got {num_words} instead.')

    device = self.model.logit_scale.device
    input_embs = []
    input_ids = []
    add_bos = True  # Only the first text segment keeps its BOS token.

    for p in prompts:
        # isinstance (not an exact type match) so that PIL subclasses such as
        # the objects returned by Image.open() are accepted as well.
        if isinstance(p, Image.Image):
            pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
            pixel_values = pixel_values.to(device=device, dtype=self.model.logit_scale.dtype)
            pixel_values = pixel_values[None, ...]  # Add the batch dimension.
            visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning')
            input_embs.append(visual_embs)
        elif isinstance(p, str):
            text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(device)
            if not add_bos:
                # Drop the leading token added by the tokenizer.
                text_ids = text_ids[:, 1:]
            else:
                add_bos = False
            text_embs = self.model.input_embeddings(text_ids)
            input_embs.append(text_embs)
            input_ids.append(text_ids)
        else:
            raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')

    if not input_embs:
        raise ValueError('prompts should contain at least one PIL.Image.Image or str.')
    input_embs = torch.cat(input_embs, dim=1)

    if num_words == 0:
        # The prompt token ids are only needed on this path, so image-only
        # prompts remain usable for generation.
        if not input_ids:
            raise ValueError('At least one text prompt is required when num_words is 0.')
        # NOTE(review): generated_ids only covers the text prompts while
        # `embeddings` also has one position per visual token, so [RET]
        # positions may not line up when images precede them - confirm usage.
        generated_ids = torch.cat(input_ids, dim=1)
        outputs = self.model.lm(inputs_embeds=input_embs, use_cache=False, output_hidden_states=True)
        layers = self.model.args.text_emb_layers
        # Same layer combination as forward()/generate(): projection heads are
        # applied only when a shared embedding space is configured.
        if self.model.args.shared_emb_dim is not None:
            out = [fc(outputs.hidden_states[x]) for x, fc in zip(layers, self.model.text_hidden_fcs)]
        else:
            out = [outputs.hidden_states[x] for x in layers]
        embeddings = torch.stack(out, dim=-1).sum(dim=-1)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    else:
        generated_ids, generated_embeddings, _ = self.model.generate(
            input_embs, num_words,
            temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
        # The last step's embeddings cover the prompt plus all fed-back tokens;
        # keep only the positions after the prompt.
        embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]

        # Truncate at the first newline, unless it is the very first token.
        newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
        newline_positions = (generated_ids[0] == newline_token_id).nonzero(as_tuple=True)[0]
        trunc_idx = int(newline_positions[0]) if newline_positions.numel() > 0 else 0
        if trunc_idx > 0:
            generated_ids = generated_ids[:, :trunc_idx]
            embeddings = embeddings[:, :trunc_idx]

    return_outputs = []
    ret_positions = (generated_ids[0, :] == self.model.retrieval_token_idx).nonzero(as_tuple=True)[0].tolist()
    # A [RET] produced at the final decoding step was never fed back into the
    # LM, so no embedding exists for it; skip it instead of indexing out of range.
    all_ret_idx = [idx for idx in ret_positions if idx < embeddings.shape[1]][:max_num_rets]
    seen_image_idx = []  # Avoid returning the same image for several [RET] tokens.
    last_ret_idx = 0

    if len(all_ret_idx) == 0:
        # No retrieval requested: return the decoded text only.
        caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return_outputs.append(utils.truncate_caption(caption))
        return return_outputs

    if self.emb_matrix is None or self.path_array is None:
        raise ValueError('emb_matrix and path_array are required to retrieve images for [RET] tokens.')

    for ret_idx in all_ret_idx:
        ret_emb = embeddings[:, ret_idx, :]
        scores = self.emb_matrix @ ret_emb.T

        # Push already returned images to the bottom of the ranking.
        for seen_idx in seen_image_idx:
            scores[seen_idx, :] -= 1000

        # Retrieve a few candidates in case some images cannot be loaded.
        num_candidates = min(3, scores.shape[0])
        _, top_image_idx = scores.squeeze().topk(num_candidates)

        image_outputs = []
        for img_idx in top_image_idx:
            seen_image_idx.append(img_idx)
            try:
                img = utils.get_image_from_url(self.path_array[img_idx])
            except UnidentifiedImageError:
                # Best effort: skip unreadable images and try the next candidate.
                continue
            image_outputs.append(img)
            if len(image_outputs) == max_num_rets:
                break

        caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
        last_ret_idx = ret_idx + 1
        return_outputs.append(utils.truncate_caption(caption) + ' [RET]')
        return_outputs.append(image_outputs)

    return return_outputs
def load_fromage(model_dir: str, ckpt_path: str) -> "Fromage":
    """Load a pretrained ``Fromage`` model and its image retrieval index.

    Args:
        model_dir: Directory containing ``model_args.json`` and one or more
            ``cc3m_embeddings*.pkl`` files.
        ckpt_path: Path of the checkpoint file holding the model ``state_dict``.

    Returns:
        The model in eval mode, cast to bfloat16 and moved to CUDA, with its
        ``emb_matrix`` L2-normalised and scaled by ``exp(logit_scale)``.

    Raises:
        ValueError: If ``model_args.json``, the checkpoint or the embedding
            files cannot be found.
    """
    model_args_path = os.path.join(model_dir, 'model_args.json')
    model_ckpt_path = os.fspath(ckpt_path)
    # Sorted so that the order of the retrieval index is reproducible.
    embs_paths = sorted(glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl')))

    if not os.path.exists(model_args_path):
        raise ValueError(f'model_args.json does not exist in {model_dir}.')
    if not os.path.exists(model_ckpt_path):
        raise ValueError(f'Checkpoint file {model_ckpt_path} does not exist.')
    if not embs_paths:
        raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {model_dir}.')

    # Image paths and their embeddings; row i of emb_matrix belongs to path_array[i].
    path_array = []
    emb_matrix = []
    for p in embs_paths:
        # SECURITY: pickle can execute arbitrary code while loading. Only use
        # embedding files from a trusted source.
        with open(p, 'rb') as wf:
            train_embs_data = pkl.load(wf)
        path_array.extend(train_embs_data['paths'])
        emb_matrix.append(train_embs_data['embeddings'])
    emb_matrix = np.concatenate(emb_matrix, axis=0)
    assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape[0])

    with open(model_args_path, 'r') as f:
        model_kwargs = json.load(f)

    # Tokenizer set-up: pad with EOS, add the image token and the [RET] token.
    tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_special_tokens({"cls_token": "<|image|>"})
    tokenizer.add_tokens('[RET]')
    ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
    assert len(ret_token_idx) == 1, ret_token_idx
    model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
    args = namedtuple('args', model_kwargs)(**model_kwargs)

    model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
    model = model.eval()
    model = model.bfloat16()
    model = model.cuda()

    # strict=False: the checkpoint is not required to match the model's
    # parameter names one-to-one.
    checkpoint = torch.load(model_ckpt_path)
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    with torch.no_grad():
        # The [RET] input embedding is stored under its own checkpoint key and
        # copied into the matching row of the LM embedding table.
        ret_embedding = checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach()
        model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(ret_embedding)

    # Pre-normalise and pre-scale the index so that retrieval is one matmul.
    logit_scale = model.model.logit_scale.exp()
    emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
    emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
    emb_matrix = logit_scale * emb_matrix
    model.emb_matrix = emb_matrix

    return model