|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import transformers |
|
from transformers import PreTrainedModel |
|
|
|
from src import loss |
|
from src import vision_model |
|
from src.config import TinyCLIPConfig |
|
from src.config import TinyCLIPTextConfig |
|
from src.config import TinyCLIPVisionConfig |
|
|
|
|
|
class Projection(nn.Module):
    """Residual projection head: linear map, gated residual branch, layer norm.

    The input is first mapped ``d_in -> d_out`` by ``linear1``. A residual
    branch (GELU -> ``linear2`` -> dropout) is computed from that projection,
    added back to it, and the sum is normalised with ``LayerNorm``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output.
        p: Dropout probability applied to the residual branch.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        # Attribute names (and creation order) define the state_dict keys and
        # the seeded initialisation sequence, so they are kept stable.
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``d_out`` features and normalise the result."""
        projected = self.linear1(x)
        residual = F.gelu(projected)
        residual = self.drop(self.linear2(residual))
        return self.layer_norm(projected + residual)
|
|
|
|
|
def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module:
    """Stack ``Projection`` blocks into a single ``nn.Sequential`` head.

    The first ``num_layers - 1`` blocks keep the width at ``d_in`` and are
    each followed by a GELU; a final ``Projection(d_in, d_out)`` maps to the
    output width. When ``num_layers <= 1`` only that final block is built.

    Args:
        d_in: Input feature size.
        d_out: Output feature size.
        num_layers: Total number of ``Projection`` blocks requested.

    Returns:
        An ``nn.Sequential`` of the blocks described above.
    """
    hidden: list[nn.Module] = []
    for _ in range(num_layers - 1):
        hidden.append(Projection(d_in, d_in))
        hidden.append(nn.GELU())
    return nn.Sequential(*hidden, Projection(d_in, d_out))
|
|
|
|
|
def mean_pooling(
    text_representation: torch.FloatTensor, attention_mask: torch.LongTensor
) -> torch.FloatTensor:
    """Average token states over dimension 1, ignoring masked-out positions.

    Args:
        text_representation: Token states whose shape is ``attention_mask``'s
            shape plus one trailing feature dimension (presumably
            ``(batch, seq_len, hidden)`` — confirm against callers).
        attention_mask: Mask with non-zero entries for positions to include.

    Returns:
        The masked mean over dimension 1, in the same dtype as
        ``text_representation``. Rows whose mask is all zero yield zeros.
    """
    # Accumulate in (at least) float32 for precision, then cast back so the
    # pooled output has the same dtype as the CLS-token path; otherwise a
    # half/bfloat16 model would receive float32 features in its projection.
    mask = attention_mask.unsqueeze(-1).float()
    summed = torch.sum(text_representation * mask, dim=1)
    # Guard against division by zero for rows with no valid tokens.
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return (summed / counts).to(text_representation.dtype)
|
|
|
|
|
class TinyCLIPTextEncoder(PreTrainedModel):
    """Text tower: pretrained ``AutoModel`` backbone plus a projection head.

    The backbone's ``last_hidden_state`` is pooled either by taking the first
    token (when ``config.cls_type`` is truthy) or by masked mean pooling, then
    projected to ``config.embed_dims`` and L2-normalised on the last axis.
    """

    config_class = TinyCLIPTextConfig

    def __init__(self, config: TinyCLIPTextConfig):
        super().__init__(config)
        self.base = transformers.AutoModel.from_pretrained(config.text_model)
        self.cls_type = config.cls_type
        backbone_width = self.base.config.hidden_size
        self.projection = projection_layers(
            backbone_width, config.embed_dims, config.projection_layers
        )

    def forward(self, x: dict[str, torch.Tensor]):
        """Encode tokenizer output ``x`` into unit-norm text embeddings.

        ``x`` is splatted into the backbone as keyword arguments; it must
        contain ``"attention_mask"`` when mean pooling is used.
        """
        token_states = self.base(**x).last_hidden_state
        pooled = (
            token_states[:, 0]
            if self.cls_type
            else mean_pooling(token_states, x["attention_mask"])
        )
        return F.normalize(self.projection(pooled), dim=-1)
|
|
|
|
|
class TinyCLIPVisionEncoder(PreTrainedModel):
    """Vision tower: backbone from ``vision_model`` plus a projection head.

    ``vision_model.get_vision_base`` supplies the backbone together with its
    output feature size; those features are projected to ``config.embed_dims``
    and L2-normalised on the last axis.
    """

    config_class = TinyCLIPVisionConfig

    def __init__(self, config: TinyCLIPVisionConfig):
        super().__init__(config)
        self.base, feature_dim = vision_model.get_vision_base(config)
        self.projection = projection_layers(
            feature_dim, config.embed_dims, config.projection_layers
        )

    def forward(self, images: torch.Tensor):
        """Encode an image batch into unit-norm image embeddings."""
        features = self.base(images)
        embeddings = self.projection(features)
        return F.normalize(embeddings, dim=-1)
|
|
|
|
|
class TinyCLIP(PreTrainedModel):
    """CLIP-style dual encoder combining a text tower and a vision tower.

    When ``config.freeze_text_base`` / ``config.freeze_vision_base`` is set,
    the corresponding pretrained *backbone* is frozen (no gradients, kept in
    eval mode) while its projection head remains trainable.
    """

    config_class = TinyCLIPConfig

    def __init__(self, config: TinyCLIPConfig):
        super().__init__(config)
        self.text_encoder = TinyCLIPTextEncoder(config.text_config)
        self.vision_encoder = TinyCLIPVisionEncoder(config.vision_config)

        # Freeze only the backbones. Previously every parameter of the encoder
        # (including the randomly initialised projection head) was frozen, so
        # with both flags set nothing in the towers could train.
        if config.freeze_text_base:
            self._freeze(self.text_encoder.base)

        if config.freeze_vision_base:
            self._freeze(self.vision_encoder.base)

        self.loss_fn = loss.get_loss(config.loss_type)

    @staticmethod
    def _freeze(module: nn.Module) -> None:
        """Switch ``module`` to eval mode and disable gradients for its params."""
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    def train(self, mode: bool = True) -> "TinyCLIP":
        """Set training mode, keeping frozen backbones in eval mode.

        ``nn.Module.train`` recurses into every submodule, which would undo
        the ``eval()`` applied to frozen backbones (re-enabling dropout and
        normalisation-statistic updates) whenever a training loop calls
        ``model.train()``.
        """
        super().train(mode)
        if self.config.freeze_text_base:
            self.text_encoder.base.eval()
        if self.config.freeze_vision_base:
            self.vision_encoder.base.eval()
        return self

    def forward(
        self,
        text_input: dict[str, torch.Tensor],
        vision_input: torch.Tensor,
        return_loss: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Embed a text batch and an image batch.

        Args:
            text_input: Keyword inputs forwarded to the text encoder.
            vision_input: Image tensor forwarded to the vision encoder.
            return_loss: When true, also compute ``self.loss_fn`` on the
                (vision, text) embedding pair.

        Returns:
            A dict with ``"text_output"`` and ``"vision_output"`` embeddings,
            plus ``"loss"`` when ``return_loss`` is true.
        """
        text_output = self.text_encoder(text_input)
        vision_output = self.vision_encoder(vision_input)

        out = {"text_output": text_output, "vision_output": vision_output}

        if return_loss:
            out["loss"] = self.loss_fn(vision_output, text_output)

        return out
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a model from the default config and show its layout.
    clip_model = TinyCLIP(TinyCLIPConfig())
    print(clip_model)
    print("Done!")
|
|