nanamma committed on
Commit 03a425d
1 Parent(s): d64a547
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ datasets:
3
+ - OpenGVLab/InternVid
4
+ base_model:
5
+ - openai/clip-vit-base-patch16
6
+ tags:
7
+ - ViCLIP
8
+ ---
9
+
10
+ Hugging Face weights of ViCLIP.
11
+
12
+ Remember to set your `tokenizer_path` in `config.json`.
13
+
14
+ Usage is shown in `demo.ipynb`.
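For quick reference, a minimal loading sketch that mirrors the first cell of `demo.ipynb` in this commit (the hub id `OpenGVLab/ViCLIP-B-16-hf` is the one used there):

```python
# Minimal sketch, assuming the hub id used in demo.ipynb; trust_remote_code pulls
# configuration_viclip.py and viclip.py from this repo.
from transformers import AutoModel

model = AutoModel.from_pretrained("OpenGVLab/ViCLIP-B-16-hf", trust_remote_code=True)
tokenizer = model.tokenizer  # SimpleTokenizer built from tokenizer_path in config.json
model_tokenizer = {"viclip": model, "tokenizer": tokenizer}
```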
bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "architectures": [
3
+ "ViCLIP"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_viclip.Config",
7
+ "AutoModel": "viclip.ViCLIP"
8
+ },
9
+ "torch_dtype": "float32",
10
+ "size":"b",
11
+ "tokenizer_path":"./bpe_simple_vocab_16e6.txt.gz"
12
+ }
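The `tokenizer_path` noted in the README can also be overridden at load time rather than by editing this file; a hedged sketch, reusing the hub id from `demo.ipynb`:

```python
# Sketch: override tokenizer_path without editing config.json.
# The auto_map above resolves Config and ViCLIP from this repo's custom code.
from transformers import AutoConfig, AutoModel

cfg = AutoConfig.from_pretrained("OpenGVLab/ViCLIP-B-16-hf", trust_remote_code=True)
cfg.tokenizer_path = "./bpe_simple_vocab_16e6.txt.gz"  # point at your local copy
model = AutoModel.from_pretrained("OpenGVLab/ViCLIP-B-16-hf", config=cfg,
                                  trust_remote_code=True)
```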
configuration_viclip.py ADDED
@@ -0,0 +1,5 @@
1
+ from transformers import PretrainedConfig
2
+
3
+ class Config(PretrainedConfig):
4
+ def __init__(self, **kwargs):
5
+ super().__init__(**kwargs)
demo.ipynb ADDED
@@ -0,0 +1,158 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a436c0a1-3410-4a7f-a186-9246075ac815",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoModel\n",
11
+ "model=AutoModel.from_pretrained(\"OpenGVLab/ViCLIP-B-16-hf\",trust_remote_code=True)\n",
12
+ "tokenizer = model.tokenizer\n",
13
+ "model_tokenizer={\"viclip\":model,\"tokenizer\":tokenizer}\n",
14
+ "print(\"done\")"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 3,
20
+ "id": "a425a5da-ceaf-4b89-9845-c8ba576902d8",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "# video data\n",
25
+ "import numpy as np\n",
26
+ "import os\n",
27
+ "import cv2\n",
28
+ "import torch\n",
29
+ "def _frame_from_video(video):\n",
30
+ " while video.isOpened():\n",
31
+ " success, frame = video.read()\n",
32
+ " if success:\n",
33
+ " yield frame\n",
34
+ " else:\n",
35
+ " break\n",
36
+ "video = cv2.VideoCapture('example1.mp4')\n",
37
+ "frames = [x for x in _frame_from_video(video)]"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": 5,
43
+ "id": "aac775ce",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# function\n",
48
+ "\n",
49
+ "def get_text_feat_dict(texts, clip, tokenizer, text_feat_d={}):\n",
50
+ " for t in texts:\n",
51
+ " feat = clip.get_text_features(t, tokenizer, text_feat_d)\n",
52
+ " text_feat_d[t] = feat\n",
53
+ " return text_feat_d\n",
54
+ "\n",
55
+ "def get_vid_feat(frames, clip):\n",
56
+ " return clip.get_vid_features(frames)\n",
57
+ "\n",
58
+ "v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3)\n",
59
+ "v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3)\n",
60
+ "def normalize(data):\n",
61
+ " return (data/255.0-v_mean)/v_std\n",
62
+ "\n",
63
+ "def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):\n",
64
+ " assert(len(vid_list) >= fnum)\n",
65
+ " step = len(vid_list) // fnum\n",
66
+ " vid_list = vid_list[::step][:fnum]\n",
67
+ " vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list]\n",
68
+ " vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]\n",
69
+ " vid_tube = np.concatenate(vid_tube, axis=1)\n",
70
+ " vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))\n",
71
+ " vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()\n",
72
+ " return vid_tube\n",
73
+ "def retrieve_text(frames, \n",
74
+ " texts, \n",
75
+ " models={'viclip':None, \n",
76
+ " 'tokenizer':None},\n",
77
+ " topk=5, \n",
78
+ " device=torch.device('cuda')):\n",
79
+ " # clip, tokenizer = get_clip(name, model_cfg['size'], model_cfg['pretrained'], model_cfg['reload'])\n",
80
+ " assert(type(models)==dict and models['viclip'] is not None and models['tokenizer'] is not None)\n",
81
+ " clip, tokenizer = models['viclip'], models['tokenizer']\n",
82
+ " clip = clip.to(device)\n",
83
+ " frames_tensor = frames2tensor(frames, device=device)\n",
84
+ " vid_feat = get_vid_feat(frames_tensor, clip)\n",
85
+ "\n",
86
+ " text_feat_d = {}\n",
87
+ " text_feat_d = get_text_feat_dict(texts, clip, tokenizer, text_feat_d)\n",
88
+ " text_feats = [text_feat_d[t] for t in texts]\n",
89
+ " text_feats_tensor = torch.cat(text_feats, 0)\n",
90
+ " \n",
91
+ " probs, idxs = clip.get_predict_label(vid_feat, text_feats_tensor, top=topk)\n",
92
+ "\n",
93
+ " ret_texts = [texts[i] for i in idxs.numpy()[0].tolist()]\n",
94
+ " return ret_texts, probs.numpy()[0]"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "a2969ba6-19d0-4893-b071-b82fa046c312",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "# retrieval\n",
105
+ "text_candidates = [\"A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon.\",\n",
106
+ " \"A man in a gray coat walks through the snowy landscape, pulling a sleigh loaded with toys.\",\n",
107
+ " \"A person dressed in a blue jacket shovels the snow-covered pavement outside their house.\",\n",
108
+ " \"A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner.\",\n",
109
+ " \"A person stands on the snowy floor, pushing a sled loaded with blankets, preparing for a fun-filled ride.\",\n",
110
+ " \"A man in a gray hat and coat walks through the snowy yard, carefully navigating around the trees.\",\n",
111
+ " \"A playful dog slides down a snowy hill, wagging its tail with delight.\",\n",
112
+ " \"A person in a blue jacket walks their pet on a leash, enjoying a peaceful winter walk among the trees.\",\n",
113
+ " \"A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run.\",\n",
114
+ " \"A person bundled up in a blanket walks through the snowy landscape, enjoying the serene winter scenery.\"]\n",
115
+ "texts, probs = retrieve_text(frames, text_candidates, models=model_tokenizer, topk=5)\n",
116
+ "\n",
117
+ "for t, p in zip(texts, probs):\n",
118
+ " print(f'text: {t} ~ prob: {p:.4f}')\n",
119
+ " \n",
120
+ "\n",
121
+ "# text: A playful dog and its owner wrestle in the snowy yard, chasing each other with joyous abandon. ~ prob: 0.8192\n",
122
+ "# text: A man in a gray sweater plays fetch with his dog in the snowy yard, throwing a toy and watching it run. ~ prob: 0.1084\n",
123
+ "# text: A pet dog excitedly runs through the snowy yard, chasing a toy thrown by its owner. ~ prob: 0.0676\n",
124
+ "# text: A playful dog slides down a snowy hill, wagging its tail with delight. ~ prob: 0.0047\n",
125
+ "# text: A person dressed in a blue jacket shovels the snow-covered pavement outside their house. ~ prob: 0.0002"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "84922de7-b41c-41c1-87a0-b28e52da9b5d",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": []
135
+ }
136
+ ],
137
+ "metadata": {
138
+ "kernelspec": {
139
+ "display_name": "Python 3 (ipykernel)",
140
+ "language": "python",
141
+ "name": "python3"
142
+ },
143
+ "language_info": {
144
+ "codemirror_mode": {
145
+ "name": "ipython",
146
+ "version": 3
147
+ },
148
+ "file_extension": ".py",
149
+ "mimetype": "text/x-python",
150
+ "name": "python",
151
+ "nbconvert_exporter": "python",
152
+ "pygments_lexer": "ipython3",
153
+ "version": "3.10.4"
154
+ }
155
+ },
156
+ "nbformat": 4,
157
+ "nbformat_minor": 5
158
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af9cfba8b30a4d62fec6bf7a033f748514ae86d04bd53b3a47f17dd4c7af2741
3
+ size 598452684
simple_tokenizer.py ADDED
@@ -0,0 +1,135 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+ # @lru_cache()
14
+ # def default_bpe():
15
+ # return "bpe_simple_vocab_16e6.txt.gz"
16
+
17
+
18
+ @lru_cache()
19
+ def bytes_to_unicode():
20
+ """
21
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
22
+ The reversible bpe codes work on unicode strings.
23
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
24
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
25
+ This is a significant percentage of your normal, say, 32K bpe vocab.
26
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
27
+ These tables also avoid mapping to whitespace/control characters the bpe code barfs on.
28
+ """
29
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
30
+ cs = bs[:]
31
+ n = 0
32
+ for b in range(2**8):
33
+ if b not in bs:
34
+ bs.append(b)
35
+ cs.append(2**8+n)
36
+ n += 1
37
+ cs = [chr(n) for n in cs]
38
+ return dict(zip(bs, cs))
39
+
40
+
41
+ def get_pairs(word):
42
+ """Return set of symbol pairs in a word.
43
+ Word is represented as tuple of symbols (symbols being variable-length strings).
44
+ """
45
+ pairs = set()
46
+ prev_char = word[0]
47
+ for char in word[1:]:
48
+ pairs.add((prev_char, char))
49
+ prev_char = char
50
+ return pairs
51
+
52
+
53
+ def basic_clean(text):
54
+ text = ftfy.fix_text(text)
55
+ text = html.unescape(html.unescape(text))
56
+ return text.strip()
57
+
58
+
59
+ def whitespace_clean(text):
60
+ text = re.sub(r'\s+', ' ', text)
61
+ text = text.strip()
62
+ return text
63
+
64
+
65
+ class SimpleTokenizer(object):
66
+ def __init__(self, bpe_path: str = default_bpe()):
67
+ self.byte_encoder = bytes_to_unicode()
68
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
69
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
70
+ merges = merges[1:49152-256-2+1]
71
+ merges = [tuple(merge.split()) for merge in merges]
72
+ vocab = list(bytes_to_unicode().values())
73
+ vocab = vocab + [v+'</w>' for v in vocab]
74
+ for merge in merges:
75
+ vocab.append(''.join(merge))
76
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
77
+ self.encoder = dict(zip(vocab, range(len(vocab))))
78
+ self.decoder = {v: k for k, v in self.encoder.items()}
79
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
80
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
81
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
82
+
83
+ def bpe(self, token):
84
+ if token in self.cache:
85
+ return self.cache[token]
86
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
87
+ pairs = get_pairs(word)
88
+
89
+ if not pairs:
90
+ return token+'</w>'
91
+
92
+ while True:
93
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
94
+ if bigram not in self.bpe_ranks:
95
+ break
96
+ first, second = bigram
97
+ new_word = []
98
+ i = 0
99
+ while i < len(word):
100
+ try:
101
+ j = word.index(first, i)
102
+ new_word.extend(word[i:j])
103
+ i = j
104
+ except ValueError:
105
+ new_word.extend(word[i:])
106
+ break
107
+
108
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
109
+ new_word.append(first+second)
110
+ i += 2
111
+ else:
112
+ new_word.append(word[i])
113
+ i += 1
114
+ new_word = tuple(new_word)
115
+ word = new_word
116
+ if len(word) == 1:
117
+ break
118
+ else:
119
+ pairs = get_pairs(word)
120
+ word = ' '.join(word)
121
+ self.cache[token] = word
122
+ return word
123
+
124
+ def encode(self, text):
125
+ bpe_tokens = []
126
+ text = whitespace_clean(basic_clean(text)).lower()
127
+ for token in re.findall(self.pat, text):
128
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
129
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
130
+ return bpe_tokens
131
+
132
+ def decode(self, tokens):
133
+ text = ''.join([self.decoder[token] for token in tokens])
134
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
135
+ return text
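A small round-trip sketch for the tokenizer above, assuming it is run from a checkout of this repo so `default_bpe()` finds `bpe_simple_vocab_16e6.txt.gz` next to the module:

```python
# Hypothetical standalone use; inside the model, viclip.ViCLIP builds this
# tokenizer for you and exposes it as model.tokenizer.
from simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()                      # loads bpe_simple_vocab_16e6.txt.gz
ids = tok.encode("a dog plays in the snow")  # BPE ids; no <|startoftext|>/<|endoftext|> added here
print(ids)
print(tok.decode(ids))                       # "a dog plays in the snow "
```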
viclip.py ADDED
@@ -0,0 +1,281 @@
1
+ import os
2
+ import logging
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import nn
7
+ import math
8
+
9
+ # from .criterions import VTC_VTM_Loss
10
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
11
+ from .viclip_vision import clip_joint_l14, clip_joint_b16
12
+ from .viclip_text import clip_text_l14, clip_text_b16
13
+
14
+ # from transformers import AutoModel
15
+ from transformers import PreTrainedModel #new
16
+ from transformers import PretrainedConfig
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ from .configuration_viclip import Config
21
+ # class ViCLIP(nn.Module):
22
+ class ViCLIP(PreTrainedModel):
23
+ _auto_class="AutoModel"
24
+ config_class=Config
25
+
26
+ def __init__(self,
27
+ # tokenizer=None, # config:PretrainedConfig is the only parameter
28
+ # size='l',
29
+ # pretrain=None,
30
+ # freeze_text=True,
31
+ config=PretrainedConfig()):
32
+ super(ViCLIP, self).__init__(config)
33
+ self.config=config
34
+ if 'size' in config.to_dict():  # size ('b' or 'l') comes from config.json
35
+ size=config.size
36
+ pretrain=None
37
+ tokenizer_path=config.tokenizer_path
38
+ tokenizer=None
39
+ freeze_text=True
40
+
41
+ if tokenizer:
42
+ self.tokenizer = tokenizer
43
+ elif tokenizer_path:
44
+ self.tokenizer = _Tokenizer(tokenizer_path)
45
+ else:
46
+ self.tokenizer = _Tokenizer()
47
+ self.max_txt_l = 32
48
+
49
+ if size.lower() == 'l':
50
+ self.vision_encoder_name = 'vit_l14'
51
+ elif size.lower() == 'b':
52
+ self.vision_encoder_name = 'vit_b16'
53
+ else:
54
+ raise NotImplementedError(f"Size {size} not implemented")
55
+
56
+ self.vision_encoder_pretrained = False
57
+ self.inputs_image_res = 224
58
+ self.vision_encoder_kernel_size = 1
59
+ self.vision_encoder_center = True
60
+ self.video_input_num_frames = 8
61
+ self.vision_encoder_drop_path_rate = 0.1
62
+ self.vision_encoder_checkpoint_num = 24
63
+ self.is_pretrain = pretrain
64
+ self.vision_width = 1024
65
+ self.text_width = 768
66
+ self.embed_dim = 768
67
+ self.masking_prob = 0.9
68
+
69
+ if size.lower() == 'l':
70
+ self.text_encoder_name = 'vit_l14'
71
+ elif size.lower() == 'b':
72
+ self.text_encoder_name = 'vit_b16'
73
+ else:
74
+ raise NotImplementedError(f"Size {size} not implemented")
75
+
76
+ self.text_encoder_pretrained = False  # 'bert-base-uncased'
77
+ self.text_encoder_d_model = 768
78
+
79
+ self.text_encoder_vocab_size = 49408
80
+
81
+ # create modules.
82
+ self.vision_encoder = self.build_vision_encoder()
83
+ self.text_encoder = self.build_text_encoder()
84
+
85
+ self.temp = nn.parameter.Parameter(torch.ones([]) * 1 / 100.0)
86
+ self.temp_min = 1 / 100.0
87
+
88
+ if pretrain:
89
+ logger.info(f"Load pretrained weights from {pretrain}")
90
+ state_dict = torch.load(pretrain, map_location='cpu')['model']
91
+ self.load_state_dict(state_dict)
92
+
93
+ # Freeze weights
94
+ if freeze_text:
95
+ self.freeze_text()
96
+
97
+
98
+ def freeze_text(self):
99
+ """freeze text encoder"""
100
+ for p in self.text_encoder.parameters():
101
+ p.requires_grad = False
102
+
103
+ def no_weight_decay(self):
104
+ ret = {"temp"}
105
+ ret.update(
106
+ {"vision_encoder." + k for k in self.vision_encoder.no_weight_decay()}
107
+ )
108
+ ret.update(
109
+ {"text_encoder." + k for k in self.text_encoder.no_weight_decay()}
110
+ )
111
+
112
+ return ret
113
+
114
+ def forward(self, image, text, raw_text, idx, log_generation=None, return_sims=False):
115
+ """forward and calculate loss.
116
+
117
+ Args:
118
+ image (torch.Tensor): The input images. Shape: [B,T,C,H,W].
119
+ text (dict): TODO
120
+ idx (torch.Tensor): TODO
121
+
122
+ Returns: TODO
123
+
124
+ """
125
+ self.clip_contrastive_temperature()
126
+
127
+ vision_embeds = self.encode_vision(image)
128
+ text_embeds = self.encode_text(raw_text)
129
+ if return_sims:
130
+ sims = torch.nn.functional.normalize(vision_embeds, dim=-1) @ \
131
+ torch.nn.functional.normalize(text_embeds, dim=-1).transpose(0, 1)
132
+ return sims
133
+
134
+ # calculate loss
135
+
136
+ ## VTC loss
137
+ loss_vtc = self.clip_loss.vtc_loss(
138
+ vision_embeds, text_embeds, idx, self.temp, all_gather=True
139
+ )
140
+
141
+ return dict(
142
+ loss_vtc=loss_vtc,
143
+ )
144
+
145
+ def encode_vision(self, image, test=False):
146
+ """encode image / videos as features.
147
+
148
+ Args:
149
+ image (torch.Tensor): The input images.
150
+ test (bool): Whether testing.
151
+
152
+ Returns: tuple.
153
+ - vision_embeds (torch.Tensor): The features of all patches. Shape: [B,T,L,C].
154
+ - pooled_vision_embeds (torch.Tensor): The pooled features. Shape: [B,T,C].
155
+
156
+ """
157
+ if image.ndim == 5:
158
+ image = image.permute(0, 2, 1, 3, 4).contiguous()
159
+ else:
160
+ image = image.unsqueeze(2)
161
+
162
+ if not test and self.masking_prob > 0.0:
163
+ return self.vision_encoder(
164
+ image, masking_prob=self.masking_prob
165
+ )
166
+
167
+ return self.vision_encoder(image)
168
+
169
+ def encode_text(self, text):
170
+ """encode text.
171
+ Args:
172
+ text (dict): The output of huggingface's `PreTrainedTokenizer`. contains keys:
173
+ - input_ids (torch.Tensor): Token ids to be fed to a model. Shape: [B,L].
174
+ - attention_mask (torch.Tensor): The mask indicating padded tokens. Shape: [B,L]; 0 marks a padded token.
175
+ - other keys refer to "https://huggingface.co/docs/transformers/v4.21.2/en/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__".
176
+ Returns: tuple.
177
+ - text_embeds (torch.Tensor): The features of all tokens. Shape: [B,L,C].
178
+ - pooled_text_embeds (torch.Tensor): The pooled features. Shape: [B,C].
179
+
180
+ """
181
+ device = next(self.text_encoder.parameters()).device
182
+ text = self.text_encoder.tokenize(
183
+ text, context_length=self.max_txt_l
184
+ ).to(device)
185
+ text_embeds = self.text_encoder(text)
186
+ return text_embeds
187
+
188
+ @torch.no_grad()
189
+ def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
190
+ """Seems only used during pre-training"""
191
+ self.temp.clamp_(min=self.temp_min)
192
+
193
+ def build_vision_encoder(self):
194
+ """build vision encoder
195
+ Returns: (vision_encoder, vision_layernorm). Each is a `nn.Module`.
196
+
197
+ """
198
+ encoder_name = self.vision_encoder_name
199
+ if encoder_name == "vit_l14":
200
+ vision_encoder = clip_joint_l14(
201
+ pretrained=self.vision_encoder_pretrained,
202
+ input_resolution=self.inputs_image_res,
203
+ kernel_size=self.vision_encoder_kernel_size,
204
+ center=self.vision_encoder_center,
205
+ num_frames=self.video_input_num_frames,
206
+ drop_path=self.vision_encoder_drop_path_rate,
207
+ checkpoint_num=self.vision_encoder_checkpoint_num,
208
+ )
209
+ elif encoder_name == "vit_b16":
210
+ vision_encoder = clip_joint_b16(
211
+ pretrained=self.vision_encoder_pretrained,
212
+ input_resolution=self.inputs_image_res,
213
+ kernel_size=self.vision_encoder_kernel_size,
214
+ center=self.vision_encoder_center,
215
+ num_frames=self.video_input_num_frames,
216
+ drop_path=self.vision_encoder_drop_path_rate,
217
+ checkpoint_num=self.vision_encoder_checkpoint_num,
218
+ )
219
+ else:
220
+ raise NotImplementedError(f"Not implemented: {encoder_name}")
221
+
222
+ return vision_encoder
223
+
224
+ def build_text_encoder(self):
225
+ """build text_encoder and possibly video-to-text multimodal fusion encoder.
226
+ Returns: nn.Module. The text encoder
227
+
228
+ """
229
+ encoder_name = self.text_encoder_name
230
+
231
+ if encoder_name == "vit_l14":
232
+ text_encoder = clip_text_l14(
233
+ pretrained=self.text_encoder_pretrained,
234
+ context_length=self.max_txt_l,
235
+ vocab_size=self.text_encoder_vocab_size,
236
+ checkpoint_num=0,
237
+ tokenizer_path=None if not 'tokenizer_path' in self.config.to_dict() else self.config.tokenizer_path
238
+ )
239
+ elif encoder_name == "vit_b16":
240
+ text_encoder = clip_text_b16(
241
+ pretrained=self.text_encoder_pretrained,
242
+ context_length=self.max_txt_l,
243
+ vocab_size=self.text_encoder_vocab_size,
244
+ checkpoint_num=0,
245
+ tokenizer_path=None if not 'tokenizer_path' in self.config.to_dict() else self.config.tokenizer_path
246
+ )
247
+ else:
248
+ raise NotImplementedError(f"Not implemented: {encoder_name}")
249
+
250
+ return text_encoder
251
+
252
+ def get_text_encoder(self):
253
+ """get text encoder, used for text and cross-modal encoding"""
254
+ encoder = self.text_encoder
255
+ return encoder.bert if hasattr(encoder, "bert") else encoder
256
+
257
+ def get_text_features(self, input_text, tokenizer, text_feature_dict={}):
258
+ if input_text in text_feature_dict:
259
+ return text_feature_dict[input_text]
260
+ text_template= f"{input_text}"
261
+ with torch.no_grad():
262
+ # text_token = tokenizer.encode(text_template).cuda()
263
+ text_features = self.encode_text(text_template).float()
264
+ text_features /= text_features.norm(dim=-1, keepdim=True)
265
+ text_feature_dict[input_text] = text_features
266
+ return text_features
267
+
268
+ def get_vid_features(self, input_frames):
269
+ with torch.no_grad():
270
+ clip_feat = self.encode_vision(input_frames,test=True).float()
271
+ clip_feat /= clip_feat.norm(dim=-1, keepdim=True)
272
+ return clip_feat
273
+
274
+ def get_predict_label(self, clip_feature, text_feats_tensor, top=5):
275
+ label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
276
+ top_probs, top_labels = label_probs.cpu().topk(top, dim=-1)
277
+ return top_probs, top_labels
278
+
279
+
280
+ if __name__ == "__main__":
281
+ tokenizer = _Tokenizer()
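As a shape reference for the helper methods above, a sketch with random frames, assuming `model` has been loaded via `AutoModel` as in `demo.ipynb` (values are meaningless, only the interface is shown):

```python
# get_vid_features expects video as [B, T, C, H, W]; encode_vision permutes it to
# [B, C, T, H, W] before the vision tower. Works on CPU as well as GPU.
import torch

frames = torch.rand(1, 8, 3, 224, 224)                 # B=1, T=8 frames, 224x224 RGB
device = next(model.parameters()).device
vid_feat = model.get_vid_features(frames.to(device))   # L2-normalized, [1, feat_dim]
txt_feat = model.get_text_features("a dog plays in the snow", model.tokenizer)
probs, idxs = model.get_predict_label(vid_feat, txt_feat, top=1)
```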
viclip_text.py ADDED
@@ -0,0 +1,305 @@
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+ from pkg_resources import packaging
5
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ import torch.utils.checkpoint as checkpoint
12
+ import functools
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ # On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
18
+ MODEL_PATH = 'https://huggingface.co/laion'
19
+ _MODELS = {
20
+ "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
21
+ "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
22
+ }
23
+
24
+
25
+ class LayerNorm(nn.LayerNorm):
26
+ """Subclass torch's LayerNorm to handle fp16."""
27
+
28
+ def forward(self, x: torch.Tensor):
29
+ orig_type = x.dtype
30
+ ret = super().forward(x.type(torch.float32))
31
+ return ret.type(orig_type)
32
+
33
+
34
+ class QuickGELU(nn.Module):
35
+ def forward(self, x: torch.Tensor):
36
+ return x * torch.sigmoid(1.702 * x)
37
+
38
+
39
+ class ResidualAttentionBlock(nn.Module):
40
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
41
+ super().__init__()
42
+
43
+ self.attn = nn.MultiheadAttention(d_model, n_head)
44
+ self.ln_1 = LayerNorm(d_model)
45
+ self.mlp = nn.Sequential(OrderedDict([
46
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
47
+ ("gelu", QuickGELU()),
48
+ ("c_proj", nn.Linear(d_model * 4, d_model))
49
+ ]))
50
+ self.ln_2 = LayerNorm(d_model)
51
+ self.attn_mask = attn_mask
52
+
53
+ def attention(self, x: torch.Tensor):
54
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
55
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
56
+
57
+ def forward(self, x: torch.Tensor):
58
+ x = x + self.attention(self.ln_1(x))
59
+ x = x + self.mlp(self.ln_2(x))
60
+ return x
61
+
62
+
63
+ class Transformer(nn.Module):
64
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
65
+ checkpoint_num: int = 0):
66
+ super().__init__()
67
+ self.width = width
68
+ self.layers = layers
69
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
70
+
71
+ self.checkpoint_num = checkpoint_num
72
+
73
+ def forward(self, x: torch.Tensor):
74
+ if self.checkpoint_num > 0:
75
+ segments = min(self.checkpoint_num, len(self.resblocks))
76
+ return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
77
+ else:
78
+ return self.resblocks(x)
79
+
80
+
81
+ class CLIP_TEXT(nn.Module):
82
+ def __init__(
83
+ self,
84
+ embed_dim: int,
85
+ context_length: int,
86
+ vocab_size: int,
87
+ transformer_width: int,
88
+ transformer_heads: int,
89
+ transformer_layers: int,
90
+ checkpoint_num: int,
91
+ tokenizer_path:str=None,
92
+ ):
93
+ super().__init__()
94
+
95
+ self.context_length = context_length
96
+ if tokenizer_path:
97
+ self._tokenizer = _Tokenizer(tokenizer_path)
98
+ else:
99
+ self._tokenizer = _Tokenizer()
100
+
101
+ self.transformer = Transformer(
102
+ width=transformer_width,
103
+ layers=transformer_layers,
104
+ heads=transformer_heads,
105
+ attn_mask=self.build_attention_mask(),
106
+ checkpoint_num=checkpoint_num,
107
+ )
108
+
109
+ self.vocab_size = vocab_size
110
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
111
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
112
+ self.ln_final = LayerNorm(transformer_width)
113
+
114
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
115
+
116
+ def no_weight_decay(self):
117
+ return {'token_embedding', 'positional_embedding'}
118
+
119
+ @functools.lru_cache(maxsize=None)
120
+ def build_attention_mask(self):
121
+ # lazily create causal attention mask, with full attention between the vision tokens
122
+ # pytorch uses additive attention mask; fill with -inf
123
+ mask = torch.empty(self.context_length, self.context_length)
124
+ mask.fill_(float("-inf"))
125
+ mask.triu_(1) # zero out the lower diagonal
126
+ return mask
127
+
128
+ def tokenize(self, texts, context_length=77, truncate=True):
129
+ """
130
+ Returns the tokenized representation of given input string(s)
131
+ Parameters
132
+ ----------
133
+ texts : Union[str, List[str]]
134
+ An input string or a list of input strings to tokenize
135
+ context_length : int
136
+ The context length to use; all CLIP models use 77 as the context length
137
+ truncate: bool
138
+ Whether to truncate the text in case its encoding is longer than the context length
139
+ Returns
140
+ -------
141
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
142
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
143
+ """
144
+ if isinstance(texts, str):
145
+ texts = [texts]
146
+
147
+ sot_token = self._tokenizer.encoder["<|startoftext|>"]
148
+ eot_token = self._tokenizer.encoder["<|endoftext|>"]
149
+ all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
150
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
151
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
152
+ else:
153
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
154
+
155
+ for i, tokens in enumerate(all_tokens):
156
+ if len(tokens) > context_length:
157
+ if truncate:
158
+ tokens = tokens[:context_length]
159
+ tokens[-1] = eot_token
160
+ else:
161
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
162
+ result[i, :len(tokens)] = torch.tensor(tokens)
163
+
164
+ return result
165
+
166
+ def forward(self, text):
167
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
168
+
169
+ x = x + self.positional_embedding
170
+ x = x.permute(1, 0, 2) # NLD -> LND
171
+ x = self.transformer(x)
172
+ x = x.permute(1, 0, 2) # LND -> NLD
173
+ x = self.ln_final(x)
174
+
175
+ # x.shape = [batch_size, n_ctx, transformer.width]
176
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
177
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
178
+
179
+ return x
180
+
181
+
182
+ def clip_text_b16(
183
+ embed_dim=512,
184
+ context_length=77,
185
+ vocab_size=49408,
186
+ transformer_width=512,
187
+ transformer_heads=8,
188
+ transformer_layers=12,
189
+ checkpoint_num=0,
190
+ pretrained=True,
191
+ tokenizer_path:str=None,
192
+ ):
193
+ # raise NotImplementedError
194
+ model = CLIP_TEXT(
195
+ embed_dim,
196
+ context_length,
197
+ vocab_size,
198
+ transformer_width,
199
+ transformer_heads,
200
+ transformer_layers,
201
+ checkpoint_num,
202
+ tokenizer_path,
203
+ )
204
+ # pretrained = _MODELS["ViT-B/16"]
205
+ # logger.info(f"Load pretrained weights from {pretrained}")
206
+ # state_dict = torch.load(pretrained, map_location='cpu')
207
+ # model.load_state_dict(state_dict, strict=False)
208
+ # return model.eval()
209
+ if pretrained:
210
+ if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
211
+ pretrained = _MODELS[pretrained]
212
+ else:
213
+ pretrained = _MODELS["ViT-B/16"]
214
+ logger.info(f"Load pretrained weights from {pretrained}")
215
+ state_dict = torch.load(pretrained, map_location='cpu')
216
+ if context_length != state_dict["positional_embedding"].size(0):
217
+ # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
218
+ print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
219
+ if context_length < state_dict["positional_embedding"].size(0):
220
+ state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
221
+ else:
222
+ state_dict["positional_embedding"] = F.pad(
223
+ state_dict["positional_embedding"],
224
+ (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
225
+ value=0,
226
+ )
227
+
228
+ message = model.load_state_dict(state_dict, strict=False)
229
+ print(f"Load pretrained weights from {pretrained}: {message}")
230
+ return model.eval()
231
+
232
+
233
+ def clip_text_l14(
234
+ embed_dim=768,
235
+ context_length=77,
236
+ vocab_size=49408,
237
+ transformer_width=768,
238
+ transformer_heads=12,
239
+ transformer_layers=12,
240
+ checkpoint_num=0,
241
+ pretrained=True,
242
+ tokenizer_path:str=None,
243
+ ):
244
+ model = CLIP_TEXT(
245
+ embed_dim,
246
+ context_length,
247
+ vocab_size,
248
+ transformer_width,
249
+ transformer_heads,
250
+ transformer_layers,
251
+ checkpoint_num,
252
+ tokenizer_path,
253
+ )
254
+ if pretrained:
255
+ if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
256
+ pretrained = _MODELS[pretrained]
257
+ else:
258
+ pretrained = _MODELS["ViT-L/14"]
259
+ logger.info(f"Load pretrained weights from {pretrained}")
260
+ state_dict = torch.load(pretrained, map_location='cpu')
261
+ if context_length != state_dict["positional_embedding"].size(0):
262
+ # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
263
+ print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
264
+ if context_length < state_dict["positional_embedding"].size(0):
265
+ state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
266
+ else:
267
+ state_dict["positional_embedding"] = F.pad(
268
+ state_dict["positional_embedding"],
269
+ (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
270
+ value=0,
271
+ )
272
+
273
+ message = model.load_state_dict(state_dict, strict=False)
274
+ print(f"Load pretrained weights from {pretrained}: {message}")
275
+ return model.eval()
276
+
277
+
278
+ def clip_text_l14_336(
279
+ embed_dim=768,
280
+ context_length=77,
281
+ vocab_size=49408,
282
+ transformer_width=768,
283
+ transformer_heads=12,
284
+ transformer_layers=12,
285
+ ):
286
+ raise NotImplementedError
287
+ model = CLIP_TEXT(
288
+ embed_dim,
289
+ context_length,
290
+ vocab_size,
291
+ transformer_width,
292
+ transformer_heads,
293
+ transformer_layers
294
+ )
295
+ pretrained = _MODELS["ViT-L/14_336"]
296
+ logger.info(f"Load pretrained weights from {pretrained}")
297
+ state_dict = torch.load(pretrained, map_location='cpu')
298
+ model.load_state_dict(state_dict, strict=False)
299
+ return model.eval()
300
+
301
+
302
+ def build_clip(config):
303
+ model_cls = config.text_encoder.clip_teacher
304
+ model = eval(model_cls)()
305
+ return model
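A sketch of the text tower in isolation, again assuming `model` was loaded through `AutoModel` (which handles this file's relative imports); it reproduces what `ViCLIP.encode_text` does internally:

```python
# model.text_encoder is the CLIP_TEXT defined above, built with context_length=32
# (ViCLIP.max_txt_l); tokenize adds <|startoftext|>/<|endoftext|> and pads/truncates.
import torch

device = next(model.parameters()).device
tokens = model.text_encoder.tokenize(["a dog plays in the snow"], context_length=32)
with torch.no_grad():
    feats = model.text_encoder(tokens.to(device))  # [1, embed_dim] text embedding
print(feats.shape)
```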
viclip_vision.py ADDED
@@ -0,0 +1,362 @@
1
+ #!/usr/bin/env python
2
+ import os
3
+ import logging
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ from torch import nn
8
+ from einops import rearrange
9
+ from timm.models.layers import DropPath
10
+ from timm.models.registry import register_model
11
+
12
+ import torch.utils.checkpoint as checkpoint
13
+
14
+ # from models.utils import load_temp_embed_with_mismatch
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
19
+ """
20
+ Add/Remove extra temporal_embeddings as needed.
21
+ https://arxiv.org/abs/2104.00650 shows adding zero paddings works.
22
+
23
+ temp_embed_old: (1, num_frames_old, 1, d)
24
+ temp_embed_new: (1, num_frames_new, 1, d)
25
+ add_zero: bool, if True, add zero, else, interpolate trained embeddings.
26
+ """
27
+ # TODO zero pad
28
+ num_frms_new = temp_embed_new.shape[1]
29
+ num_frms_old = temp_embed_old.shape[1]
30
+ logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
31
+ if num_frms_new > num_frms_old:
32
+ if add_zero:
33
+ temp_embed_new[
34
+ :, :num_frms_old
35
+ ] = temp_embed_old # untrained embeddings are zeros.
36
+ else:
37
+ temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
38
+ elif num_frms_new < num_frms_old:
39
+ temp_embed_new = temp_embed_old[:, :num_frms_new]
40
+ else: # =
41
+ temp_embed_new = temp_embed_old
42
+ return temp_embed_new
43
+
44
+
45
+ # On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
46
+ MODEL_PATH = ''
47
+ _MODELS = {
48
+ "ViT-L/14": os.path.join(MODEL_PATH, "ViCLIP-L_InternVid-FLT-10M.pth"),
49
+ "ViT-B/16": os.path.join(MODEL_PATH, "ViCLIP-B-InternVid-FLT-10M.pth"),
50
+ }
51
+
52
+
53
+ class QuickGELU(nn.Module):
54
+ def forward(self, x):
55
+ return x * torch.sigmoid(1.702 * x)
56
+
57
+
58
+ class ResidualAttentionBlock(nn.Module):
59
+ def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
60
+ super().__init__()
61
+
62
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
63
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
64
+ # logger.info(f'Droppath: {drop_path}')
65
+ self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
66
+ self.ln_1 = nn.LayerNorm(d_model)
67
+ self.mlp = nn.Sequential(OrderedDict([
68
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
69
+ ("gelu", QuickGELU()),
70
+ ("drop1", nn.Dropout(dropout)),
71
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
72
+ ("drop2", nn.Dropout(dropout)),
73
+ ]))
74
+ self.ln_2 = nn.LayerNorm(d_model)
75
+ self.attn_mask = attn_mask
76
+
77
+ def attention(self, x):
78
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
79
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
80
+
81
+ def forward(self, x):
82
+ x = x + self.drop_path1(self.attention(self.ln_1(x)))
83
+ x = x + self.drop_path2(self.mlp(self.ln_2(x)))
84
+ return x
85
+
86
+
87
+ class Transformer(nn.Module):
88
+ def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
89
+ super().__init__()
90
+ dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
91
+ self.resblocks = nn.ModuleList()
92
+ for idx in range(layers):
93
+ self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout))
94
+ self.checkpoint_num = checkpoint_num
95
+
96
+ def forward(self, x):
97
+ for idx, blk in enumerate(self.resblocks):
98
+ if idx < self.checkpoint_num:
99
+ x = checkpoint.checkpoint(blk, x)
100
+ else:
101
+ x = blk(x)
102
+ return x
103
+
104
+
105
+ class VisionTransformer(nn.Module):
106
+ def __init__(
107
+ self, input_resolution, patch_size, width, layers, heads, output_dim=None,
108
+ kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
109
+ temp_embed=True,
110
+ ):
111
+ super().__init__()
112
+ self.output_dim = output_dim
113
+ self.conv1 = nn.Conv3d(
114
+ 3, width,
115
+ (kernel_size, patch_size, patch_size),
116
+ (kernel_size, patch_size, patch_size),
117
+ (0, 0, 0), bias=False
118
+ )
119
+
120
+ scale = width ** -0.5
121
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
122
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
123
+ self.ln_pre = nn.LayerNorm(width)
124
+ if temp_embed:
125
+ self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))
126
+
127
+ self.transformer = Transformer(
128
+ width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
129
+ dropout=dropout)
130
+
131
+ self.ln_post = nn.LayerNorm(width)
132
+ if output_dim is not None:
133
+ self.proj = nn.Parameter(torch.empty(width, output_dim))
134
+ else:
135
+ self.proj = None
136
+
137
+ self.dropout = nn.Dropout(dropout)
138
+
139
+ def get_num_layers(self):
140
+ return len(self.transformer.resblocks)
141
+
142
+ @torch.jit.ignore
143
+ def no_weight_decay(self):
144
+ return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}
145
+
146
+ def mask_tokens(self, inputs, masking_prob=0.0):
147
+ B, L, _ = inputs.shape
148
+
149
+ # This is different from text as we are masking a fix number of tokens
150
+ Lm = int(masking_prob * L)
151
+ masked_indices = torch.zeros(B, L)
152
+ indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm]
153
+ batch_indices = (
154
+ torch.arange(masked_indices.shape[0]).unsqueeze(-1).expand_as(indices)
155
+ )
156
+ masked_indices[batch_indices, indices] = 1
157
+
158
+ masked_indices = masked_indices.bool()
159
+
160
+ return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1])
161
+
162
+ def forward(self, x, masking_prob=0.0):
163
+ x = self.conv1(x) # shape = [B, width, T, grid, grid]
164
+ B, C, T, H, W = x.shape
165
+ x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)
166
+
167
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
168
+ x = x + self.positional_embedding.to(x.dtype)
169
+
170
+ # temporal pos
171
+ cls_tokens = x[:B, :1, :]
172
+ x = x[:, 1:]
173
+ x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
174
+ if hasattr(self, 'temporal_positional_embedding'):
175
+ if x.size(1) == 1:
176
+ # This is a workaround for unused parameter issue
177
+ x = x + self.temporal_positional_embedding.mean(1)
178
+ else:
179
+ x = x + self.temporal_positional_embedding
180
+ x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)
181
+
182
+ if masking_prob > 0.0:
183
+ x = self.mask_tokens(x, masking_prob)
184
+
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = self.ln_pre(x)
188
+
189
+ x = x.permute(1, 0, 2) #BND -> NBD
190
+ x = self.transformer(x)
191
+
192
+ x = self.ln_post(x)
193
+
194
+ if self.proj is not None:
195
+ x = self.dropout(x[0]) @ self.proj
196
+ else:
197
+ x = x.permute(1, 0, 2) #NBD -> BND
198
+
199
+ return x
200
+
201
+
202
+ def inflate_weight(weight_2d, time_dim, center=True):
203
+ logger.info(f'Init center: {center}')
204
+ if center:
205
+ weight_3d = torch.zeros(*weight_2d.shape)
206
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
207
+ middle_idx = time_dim // 2
208
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
209
+ else:
210
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
211
+ weight_3d = weight_3d / time_dim
212
+ return weight_3d
213
+
214
+
215
+ def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
216
+ state_dict_3d = model.state_dict()
217
+ for k in state_dict.keys():
218
+ if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape:
219
+ if len(state_dict_3d[k].shape) <= 2:
220
+ logger.info(f'Ignore: {k}')
221
+ continue
222
+ logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
223
+ time_dim = state_dict_3d[k].shape[2]
224
+ state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)
225
+
226
+ pos_embed_checkpoint = state_dict['positional_embedding']
227
+ embedding_size = pos_embed_checkpoint.shape[-1]
228
+ num_patches = (input_resolution // patch_size) ** 2
229
+ orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
230
+ new_size = int(num_patches ** 0.5)
231
+ if orig_size != new_size:
232
+ logger.info(f'Pos_emb from {orig_size} to {new_size}')
233
+ extra_tokens = pos_embed_checkpoint[:1]
234
+ pos_tokens = pos_embed_checkpoint[1:]
235
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
236
+ pos_tokens = torch.nn.functional.interpolate(
237
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
238
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
239
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
240
+ state_dict['positional_embedding'] = new_pos_embed
241
+
242
+ message = model.load_state_dict(state_dict, strict=False)
243
+ logger.info(f"Load pretrained weights: {message}")
244
+
245
+
246
+ @register_model
247
+ def clip_joint_b16(
248
+ pretrained=False, input_resolution=224, kernel_size=1,
249
+ center=True, num_frames=8, drop_path=0., checkpoint_num=0,
250
+ dropout=0.,
251
+ ):
252
+ model = VisionTransformer(
253
+ input_resolution=input_resolution, patch_size=16,
254
+ width=768, layers=12, heads=12, output_dim=512,
255
+ kernel_size=kernel_size, num_frames=num_frames,
256
+ drop_path=drop_path, checkpoint_num=checkpoint_num,
257
+ dropout=dropout,
258
+ )
259
+ # raise NotImplementedError
260
+ if pretrained:
261
+ if isinstance(pretrained, str):
262
+ model_name = pretrained
263
+ else:
264
+ model_name = "ViT-B/16"
265
+
266
+ logger.info('load pretrained weights')
267
+ state_dict = torch.load(_MODELS[model_name], map_location='cpu')
268
+ load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center)
269
+ return model.eval()
270
+
271
+
272
+ @register_model
273
+ def clip_joint_l14(
274
+ pretrained=False, input_resolution=224, kernel_size=1,
275
+ center=True, num_frames=8, drop_path=0., checkpoint_num=0,
276
+ dropout=0.,
277
+ ):
278
+ model = VisionTransformer(
279
+ input_resolution=input_resolution, patch_size=14,
280
+ width=1024, layers=24, heads=16, output_dim=768,
281
+ kernel_size=kernel_size, num_frames=num_frames,
282
+ drop_path=drop_path, checkpoint_num=checkpoint_num,
283
+ dropout=dropout,
284
+ )
285
+
286
+ if pretrained:
287
+ if isinstance(pretrained, str):
288
+ model_name = pretrained
289
+ else:
290
+ model_name = "ViT-L/14"
291
+ logger.info('load pretrained weights')
292
+ state_dict = torch.load(_MODELS[model_name], map_location='cpu')
293
+ load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
294
+ return model.eval()
295
+
296
+
297
+ @register_model
298
+ def clip_joint_l14_336(
299
+ pretrained=True, input_resolution=336, kernel_size=1,
300
+ center=True, num_frames=8, drop_path=0.
301
+ ):
302
+ raise NotImplementedError
303
+ model = VisionTransformer(
304
+ input_resolution=input_resolution, patch_size=14,
305
+ width=1024, layers=24, heads=16, output_dim=768,
306
+ kernel_size=kernel_size, num_frames=num_frames,
307
+ drop_path=drop_path,
308
+ )
309
+ if pretrained:
310
+ logger.info('load pretrained weights')
311
+ state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
312
+ load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
313
+ return model.eval()
314
+
315
+
316
+ def interpolate_pos_embed_vit(state_dict, new_model):
317
+ key = "vision_encoder.temporal_positional_embedding"
318
+ if key in state_dict:
319
+ vision_temp_embed_new = new_model.state_dict()[key]
320
+ vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2) # [1, n, d] -> [1, n, 1, d]
321
+ vision_temp_embed_old = state_dict[key]
322
+ vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2)
323
+
324
+ state_dict[key] = load_temp_embed_with_mismatch(
325
+ vision_temp_embed_old, vision_temp_embed_new, add_zero=False
326
+ ).squeeze(2)
327
+
328
+ key = "text_encoder.positional_embedding"
329
+ if key in state_dict:
330
+ text_temp_embed_new = new_model.state_dict()[key]
331
+ text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2) # [n, d] -> [1, n, 1, d]
332
+ text_temp_embed_old = state_dict[key]
333
+ text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2)
334
+
335
+ state_dict[key] = load_temp_embed_with_mismatch(
336
+ text_temp_embed_old, text_temp_embed_new, add_zero=False
337
+ ).squeeze(2).squeeze(0)
338
+ return state_dict
339
+
340
+
341
+ if __name__ == '__main__':
342
+ import time
343
+ from fvcore.nn import FlopCountAnalysis
344
+ from fvcore.nn import flop_count_table
345
+ import numpy as np
346
+
347
+ seed = 4217
348
+ np.random.seed(seed)
349
+ torch.manual_seed(seed)
350
+ torch.cuda.manual_seed(seed)
351
+ torch.cuda.manual_seed_all(seed)
352
+ num_frames = 8
353
+
354
+ # model = clip_joint_b16(pretrained=True, kernel_size=1, num_frames=8, num_classes=400, drop_path=0.1)
355
+ # logger.info(model)
356
+ model = clip_joint_l14(pretrained=False)
357
+
358
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
359
+ s = time.time()
360
+ logger.info(flop_count_table(flops, max_depth=1))
361
+ logger.info(time.time()-s)
362
+ # logger.info(model(torch.rand(1, 3, num_frames, 224, 224)).shape)