import torch
import torch.nn as nn
from transformers import CLIPPreTrainedModel, CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPVisionTransformer

from inference.model.language_model.configuration_llava_phi import LlavaPhiVisionConfig


class CLIPVisionTower(CLIPPreTrainedModel):
    """CLIP vision encoder wrapped for use as the LLaVA-Phi vision tower."""

    config_class = LlavaPhiVisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def feature_select(self, image_forward_outs):
        # Take the hidden states of the configured vision layer and optionally
        # drop the CLS token, depending on mm_vision_select_feature.
        image_features = image_forward_outs.hidden_states[
            self.config.mm_vision_select_layer
        ]
        if self.config.mm_vision_select_feature == "patch":
            # Keep only the patch tokens (drop the leading CLS token).
            image_features = image_features[:, 1:]
        elif self.config.mm_vision_select_feature == "cls_patch":
            # Keep both the CLS token and the patch tokens.
            image_features = image_features
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.mm_vision_select_feature}"
            )
        return image_features

    def forward(self, images):
        if isinstance(images, list):
            # Variable-sized inputs: encode each image separately.
            image_features = []
            for image in images:
                image_forward_out = self.vision_model(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            # Batched input of shape (batch, 3, H, W).
            image_forward_outs = self.vision_model(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_model.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_model.parameters()).device

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        # Number of patch tokens, e.g. (336 // 14) ** 2 = 576 for ViT-L/14-336.
        return (self.config.image_size // self.config.patch_size) ** 2


if __name__ == "__main__":
    clip_config = CLIPVisionConfig.from_pretrained(
        "/data/private/zhumj/GPTcode/mm-phi/openai/clip-vit-large-patch14-336"
    )
    print("################ clip_config ##############")
    print(clip_config)

    phi_vis_config = LlavaPhiVisionConfig(**clip_config.to_dict())
    print("################ phi_vis_config ##############")
    print(phi_vis_config)

    model = CLIPVisionTower(clip_config)
    # print(list(model.vision_model.parameters())[0].dtype)
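
    # Hedged usage sketch (an assumption, not part of the original script): build
    # the tower from the LLaVA-Phi vision config, which is assumed to carry the
    # mm_vision_select_layer / mm_vision_select_feature fields that
    # feature_select reads, and push a random batch through it to check shapes.
    tower = CLIPVisionTower(phi_vis_config)
    dummy_images = torch.randn(
        2, 3, phi_vis_config.image_size, phi_vis_config.image_size
    )
    features = tower(dummy_images)
    # With "patch" selection this should be (batch, num_patches, hidden_size),
    # e.g. (2, 576, 1024) for ViT-L/14 at 336x336 resolution.
    print(features.shape)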