# InstructIR / text / models.py
# mv-lab -- "InstructIR x HF" (revision 39417b0)
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
import os
# Checkpoints whose sentence embedding is the attention-masked mean of the
# token embeddings (followed by L2 normalisation) rather than the first token.
POOL_MODELS = {
    "sentence-transformers/all-MiniLM-L6-v2",
    "TaylorAI/bge-micro-v2",
}
# Mean pooling - average token embeddings, ignoring padded positions.
def mean_pooling(model_output, attention_mask):
    """Return the attention-masked mean of the token embeddings.

    Args:
        model_output: Indexable model output whose element 0 holds the
            token embeddings.
        attention_mask: Mask tensor; it gets a trailing axis added and is
            expanded to the token-embedding size, so positions with value 0
            contribute nothing to the average.

    Returns:
        The masked sum over dimension 1 divided by the mask count, with the
        count clamped to at least 1e-9 so a fully masked row yields zeros
        instead of a division by zero.
    """
    token_embeddings = model_output[0]
    # Broadcast the mask over the embedding axis so it can weight each token.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * mask).sum(dim=1)
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_count
class LanguageModel(nn.Module):
    """Frozen pre-trained text encoder mapping raw strings to sentence embeddings.

    The tokenizer and weights are loaded through ``AutoTokenizer`` /
    ``AutoModel``. Every parameter of the wrapped model is frozen and the
    wrapped model is kept in evaluation mode, including after ``train()`` is
    called on this module or on a parent module.

    Args:
        model: Model name or path forwarded to ``from_pretrained``.
    """

    def __init__(self, model='distilbert-base-uncased'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model

        # Remove the CLIP vision tower
        if "clip" in self.model_name:
            self.model.vision_model = None

        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False

        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def train(self, mode=True):
        """Set the training flag on this module but keep the encoder in eval mode.

        ``nn.Module.train`` recurses into children, which would otherwise
        switch the frozen encoder back to training mode (re-enabling dropout)
        whenever this module or a parent is put in training mode.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, text_batch):
        """Encode ``text_batch`` into one embedding per input text.

        Args:
            text_batch: Text input accepted by the tokenizer (padded and
                truncated, returned as PyTorch tensors).

        Returns:
            * CLIP models: the output of ``get_text_features``.
            * Models listed in ``POOL_MODELS``: the L2-normalised,
              attention-masked mean of the token embeddings.
            * Otherwise: the last hidden state of the first token.
        """
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")

        # The tokenizer returns CPU tensors; move them next to the weights so
        # the forward pass also works after the module is moved to another device.
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            inputs = inputs.to(first_param.device)

        with torch.no_grad():  # Ensure no gradients are computed for this forward pass
            if "clip" in self.model_name:
                return self.model.get_text_features(**inputs)

            outputs = self.model(**inputs)

        if any(name in self.model_name for name in POOL_MODELS):
            sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
            # Normalize embeddings
            sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
        else:
            # Use the first-token hidden state as the sentence embedding.
            sentence_embedding = outputs.last_hidden_state[:, 0, :]
        return sentence_embedding
class LMHead(nn.Module):
    """Two-layer linear head on top of a sentence embedding.

    Args:
        embedding_dim: Input feature size of the first linear layer.
        hidden_dim: Output size of the first layer and input size of the second.
        num_classes: Output size of the second linear layer.
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        # No activation is applied between fc1 and fc2 (a GELU here is disabled).
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Project ``x`` and compute the class logits.

        Returns:
            A tuple ``(embd, deg_pred)``: the output of ``fc1`` L2-normalised
            along dimension 1, and the output of ``fc2`` applied to that
            normalised embedding.
        """
        projected = self.fc1(x)
        embd = F.normalize(projected, p=2, dim=1)
        return embd, self.fc2(embd)