import torch
import torch.nn as nn
from transformers import CLIPModel


class VariableLengthCLIP(nn.Module):
    """CLIP vision encoder with mean-pooling over frames, so clips of any
    length map to a single logit vector."""

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        # Classification head on top of pooler_output; its width equals
        # visual_projection.in_features (the vision tower's hidden size).
        self.visual_projection = nn.Linear(clip_model.visual_projection.in_features, num_classes)

    def forward(self, x):
        # x: (batch_size, num_frames, channels, height, width)
        batch_size, num_frames, c, h, w = x.shape
        # Fold frames into the batch dimension so the image encoder sees 4D input.
        x = x.view(batch_size * num_frames, c, h, w)
        features = self.clip_model.vision_model(x).pooler_output
        # Restore the frame dimension and average per-frame features over time.
        features = features.view(batch_size, num_frames, -1)
        features = torch.mean(features, dim=1)
        return self.visual_projection(features)

    def unfreeze_vision_encoder(self, num_layers=2):
        """Freeze the vision encoder, then unfreeze its last `num_layers` transformer blocks."""
        for param in self.clip_model.vision_model.parameters():
            param.requires_grad = False
        for param in self.clip_model.vision_model.encoder.layers[-num_layers:].parameters():
            param.requires_grad = True


def create_model(num_classes, pretrained_model_name="openai/clip-vit-base-patch32"):
    clip_model = CLIPModel.from_pretrained(pretrained_model_name)
    return VariableLengthCLIP(clip_model, num_classes)


def load_model(num_classes, model_path, device, pretrained_model_name="openai/clip-vit-base-patch32"):
    model = create_model(num_classes, pretrained_model_name)
    # weights_only=True restricts torch.load to tensor/primitive data (safer deserialization).
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    # strict=False tolerates missing or unexpected keys in the checkpoint.
    model.load_state_dict(state_dict, strict=False)
    model.to(device)
    return model
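

# Illustrative usage sketch (not from the original module): builds the model for a
# hypothetical 10-class task and runs a dummy batch of two 8-frame clips at 224x224
# to show the expected input/output shapes.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(num_classes=10).to(device)
    model.unfreeze_vision_encoder(num_layers=2)  # only the last two encoder blocks stay trainable

    dummy_clips = torch.randn(2, 8, 3, 224, 224, device=device)  # (batch, frames, C, H, W)
    with torch.no_grad():
        logits = model(dummy_clips)
    print(logits.shape)  # torch.Size([2, 10])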