grostaco committed
Commit ba2ab36
1 Parent(s): b4c1a13

initial commit

.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
README.md CHANGED
@@ -9,5 +9,6 @@ app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-f
+# IRRA space
+
+Space for Text-To-Image Person retrieval for the [IRRA](https://github.com/anosorae/IRRA/tree/main) model
app.py ADDED
@@ -0,0 +1,40 @@
+import streamlit as st
+from lib.utils.model import get_model, get_similarities
+from PIL import Image
+
+st.title('IRRA Text-To-Image Retrieval')
+
+st.header('Inputs')
+caption = st.text_input('Description Input')
+
+images = st.file_uploader('Upload images', accept_multiple_files=True)
+if images is not None:
+    st.image(images)  # type: ignore
+
+st.header('Options')
+st.subheader('Ranks')
+
+ranks = st.slider('slider_ranks', min_value=1, max_value=10, label_visibility='collapsed', value=5)
+
+button = st.button('Match most similar', disabled=len(images) == 0 or caption == '')
+
+if button:
+    st.header('Results')
+    with st.spinner('Loading model'):
+        model = get_model()
+
+    st.text(f'IRRA model loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')
+
+    with st.spinner('Computing and ranking similarities'):
+        similarities = get_similarities(caption, images, model)
+
+    indices = similarities.argsort(descending=True).squeeze(0).cpu().tolist()[:ranks]
+
+    for i, idx in enumerate(indices):
+        c1, c2 = st.columns(2)
+        with c1:
+            st.text(f'Rank {i + 1}')
+        with c2:
+            st.image(images[idx])
+
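For readers skimming the diff, a minimal sketch of how the ranking step in `app.py` behaves; the similarity scores below are made up for illustration, and `get_similarities` is assumed to return a tensor of shape `[1, num_images]`, as `lib/utils/model.py` suggests.

```python
import torch

# Hypothetical scores for 4 uploaded images, shape [1, 4].
similarities = torch.tensor([[0.12, 0.87, 0.45, 0.33]])

# Same ranking logic as app.py: best score first, keep the top `ranks` indices.
ranks = 2
indices = similarities.argsort(descending=True).squeeze(0).cpu().tolist()[:ranks]
print(indices)  # [1, 2] -> image 1 is the closest match, image 2 the runner-up
```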
lib/IRRA/image.py ADDED
@@ -0,0 +1,23 @@
+import torch
+import torchvision.transforms as T
+
+from PIL import Image
+
+def prepare_images(files: list[str]):
+    mean = [0.48145466, 0.4578275, 0.40821073]
+    std = [0.26862954, 0.26130258, 0.27577711]
+
+    transforms = T.Compose([
+        T.Resize((384, 128)),
+        T.RandomHorizontalFlip(0.5),
+        T.ToTensor(),
+        T.Normalize(mean=mean, std=std),
+    ])
+
+    tensors = []
+    for file in files:
+        tensors.append(transforms(Image.open(file).convert('RGB')).unsqueeze(0))
+
+    return torch.cat(tensors, dim=0)
+
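A usage sketch for `prepare_images`, assuming two hypothetical image paths; every image is resized to 384x128 and normalized with the CLIP statistics, so the returned batch should have shape `[2, 3, 384, 128]`. Streamlit's `UploadedFile` objects also work, since `PIL.Image.open` accepts file-like objects.

```python
from lib.IRRA.image import prepare_images

# 'person_01.jpg' and 'person_02.jpg' are placeholder paths for illustration.
batch = prepare_images(['person_01.jpg', 'person_02.jpg'])
print(batch.shape)  # torch.Size([2, 3, 384, 128])
```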
lib/IRRA/model/__init__.py ADDED
@@ -0,0 +1 @@
+from .build import build_model
lib/IRRA/model/build.py ADDED
@@ -0,0 +1,150 @@
+from . import objectives
+from .clip_model import Transformer, QuickGELU, LayerNorm, build_CLIP_from_openai_pretrained, convert_weights
+import numpy as np
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+class IRRA(nn.Module):
+    def __init__(self, args, num_classes=11003):
+        super().__init__()
+        self.args = args
+        self.num_classes = num_classes
+        self._set_task()
+
+        self.base_model, base_cfg = build_CLIP_from_openai_pretrained(args.pretrain_choice, args.img_size, args.stride_size)
+        self.embed_dim = base_cfg['embed_dim']
+
+        self.logit_scale = torch.ones([]) * (1 / args.temperature)
+
+        if 'id' in args.loss_names:
+            self.classifier = nn.Linear(self.embed_dim, self.num_classes)
+            nn.init.normal_(self.classifier.weight.data, std=0.001)
+            nn.init.constant_(self.classifier.bias.data, val=0.0)
+
+        if 'mlm' in args.loss_names:
+            self.cross_attn = nn.MultiheadAttention(self.embed_dim,
+                                                    self.embed_dim // 64,
+                                                    batch_first=True)
+            self.cross_modal_transformer = Transformer(width=self.embed_dim,
+                                                       layers=args.cmt_depth,
+                                                       heads=self.embed_dim // 64)
+            scale = self.cross_modal_transformer.width ** -0.5
+
+            self.ln_pre_t = LayerNorm(self.embed_dim)
+            self.ln_pre_i = LayerNorm(self.embed_dim)
+            self.ln_post = LayerNorm(self.embed_dim)
+
+            proj_std = scale * ((2 * self.cross_modal_transformer.layers) ** -0.5)
+            attn_std = scale
+            fc_std = (2 * self.cross_modal_transformer.width) ** -0.5
+            for block in self.cross_modal_transformer.resblocks:
+                nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+                nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+                nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+                nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+            # init cross attn
+            nn.init.normal_(self.cross_attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(self.cross_attn.out_proj.weight, std=proj_std)
+
+            self.mlm_head = nn.Sequential(
+                OrderedDict([('dense', nn.Linear(self.embed_dim, self.embed_dim)),
+                             ('gelu', QuickGELU()),
+                             ('ln', LayerNorm(self.embed_dim)),
+                             ('fc', nn.Linear(self.embed_dim, args.vocab_size))]))
+            # init mlm head
+            nn.init.normal_(self.mlm_head.dense.weight, std=fc_std)
+            nn.init.normal_(self.mlm_head.fc.weight, std=proj_std)
+
+    def _set_task(self):
+        loss_names = self.args.loss_names
+        self.current_task = [l.strip() for l in loss_names.split('+')]
+        print(f'Training Model with {self.current_task} tasks')
+
+    def cross_former(self, q, k, v):
+        x = self.cross_attn(
+            self.ln_pre_t(q),
+            self.ln_pre_i(k),
+            self.ln_pre_i(v),
+            need_weights=False)[0]
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.cross_modal_transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x)
+        return x
+
+    def encode_image(self, image):
+        x = self.base_model.encode_image(image)
+        return x[:, 0, :].float()
+        # return x.float() # for CLIP ResNet visual model
+
+    def encode_text(self, text):
+        x = self.base_model.encode_text(text)
+        return x[torch.arange(x.shape[0]), text.argmax(dim=-1)].float()
+
+    def forward(self, batch):
+        ret = dict()
+
+        images = batch['images']
+        caption_ids = batch['caption_ids']
+        image_feats, text_feats = self.base_model(images, caption_ids)
+        i_feats = image_feats[:, 0, :].float()
+        # i_feats = image_feats.float() # for CLIP ResNet visual model
+        t_feats = text_feats[torch.arange(text_feats.shape[0]), caption_ids.argmax(dim=-1)].float()
+
+        logit_scale = self.logit_scale
+        ret.update({'temperature': 1 / logit_scale})
+
+        if 'itc' in self.current_task:
+            ret.update({'itc_loss': objectives.compute_itc(i_feats, t_feats, logit_scale)})
+
+        if 'sdm' in self.current_task:
+            ret.update({'sdm_loss': objectives.compute_sdm(i_feats, t_feats, batch['pids'], logit_scale)})
+
+        if 'cmpm' in self.current_task:
+            ret.update({'cmpm_loss': objectives.compute_cmpm(i_feats, t_feats, batch['pids'])})
+
+        if 'id' in self.current_task:
+            image_logits = self.classifier(i_feats.half()).float()
+            text_logits = self.classifier(t_feats.half()).float()
+            ret.update({'id_loss': objectives.compute_id(image_logits, text_logits, batch['pids']) * self.args.id_loss_weight})
+
+            image_pred = torch.argmax(image_logits, dim=1)
+            text_pred = torch.argmax(text_logits, dim=1)
+
+            image_precision = (image_pred == batch['pids']).float().mean()
+            text_precision = (text_pred == batch['pids']).float().mean()
+            ret.update({'img_acc': image_precision})
+            ret.update({'txt_acc': text_precision})
+
+        if 'mlm' in self.current_task:
+            mlm_ids = batch['mlm_ids']
+
+            mlm_feats = self.base_model.encode_text(mlm_ids)
+
+            x = self.cross_former(mlm_feats, image_feats, image_feats)
+
+            x = self.mlm_head(x)  # [batch_size, text_len, num_colors]
+
+            scores = x.float().reshape(-1, self.args.vocab_size)
+            mlm_labels = batch['mlm_labels'].reshape(-1)
+            ret.update({'mlm_loss': objectives.compute_mlm(scores, mlm_labels) * self.args.mlm_loss_weight})
+
+            pred = scores.max(1)[1]
+            mlm_label_idx = torch.nonzero(mlm_labels)
+            acc = (pred[mlm_label_idx] == mlm_labels[mlm_label_idx]).float().mean()
+            ret.update({'mlm_acc': acc})
+
+        return ret
+
+
+def build_model(args, num_classes=11003):
+    model = IRRA(args, num_classes)
+    # convert model to fp16
+    convert_weights(model)
+    return model
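`build_model` pulls its hyperparameters from an `args` namespace; the real values live in `model/configs.yaml`, which is not part of this commit, so the sketch below only shows the fields the class actually reads, with assumed values that mirror common IRRA defaults.

```python
from easydict import EasyDict
from lib.IRRA.model.build import build_model

# Assumed values for illustration; the Space loads the real ones from model/configs.yaml.
args = EasyDict(
    pretrain_choice='ViT-B/16',   # CLIP backbone passed to build_CLIP_from_openai_pretrained
    img_size=(384, 128),          # person Re-ID resolution
    stride_size=16,
    temperature=0.02,             # logit_scale = 1 / temperature
    loss_names='sdm+mlm+id',      # parsed by IRRA._set_task()
    cmt_depth=4,                  # cross-modal transformer depth (mlm branch)
    vocab_size=49408,
    id_loss_weight=1.0,
    mlm_loss_weight=1.0,
)

model = build_model(args).eval()
```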
lib/IRRA/model/clip_model.py ADDED
@@ -0,0 +1,602 @@
+""" CLIP Model
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+from collections import OrderedDict
+import logging
+import math
+import os
+from typing import List, Tuple, Union
+import hashlib
+import urllib
+from tqdm import tqdm
+import warnings
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+logger = logging.getLogger("IRRA.model")
+
+_MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+}
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+def _download(url: str, root: str):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+            return download_target
+        else:
+            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
+
+    return download_target
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        # self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.positional_embedding = nn.Parameter(torch.randn((spacial_dim[0] * spacial_dim[1]) + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x, key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+
+        return x[0]
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.avgpool = nn.AvgPool2d(2)
+        self.relu = nn.ReLU(inplace=True)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        spacial_dim = (
+            input_resolution[0] // 32,
+            input_resolution[1] // 32,
+        )
+        self.attnpool = AttentionPool2d(spacial_dim, embed_dim, heads, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+                x = self.relu(bn(conv(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.attnpool(x)
+
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: Tuple[int, int], patch_size: int, stride_size: int, width: int, layers: int, heads: int, output_dim: int):
+        super().__init__()
+        self.input_resolution = input_resolution  # (384, 128)
+        self.num_x = (input_resolution[1] - patch_size) // stride_size + 1
+        self.num_y = (input_resolution[0] - patch_size) // stride_size + 1
+        num_patches = self.num_x * self.num_y
+
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False)
+
+        scale = width ** -0.5  # 1/sqrt(768)
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # x = self.ln_post(x[:, 0, :])
+        x = self.ln_post(x)
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(self,
+                 embed_dim: int,
+                 # vision
+                 image_resolution: Union[int, Tuple[int, int]],
+                 vision_layers: Union[Tuple[int, int, int, int], int],
+                 vision_width: int,
+                 vision_patch_size: int,
+                 stride_size: int,
+                 # text
+                 context_length: int,
+                 vocab_size: int,
+                 transformer_width: int,
+                 transformer_heads: int,
+                 transformer_layers: int
+                 ):
+        super().__init__()
+
+        self.context_length = context_length
+
+        if isinstance(vision_layers, (tuple, list)):
+            vision_heads = vision_width * 32 // 64
+            self.visual = ModifiedResNet(
+                layers=vision_layers,
+                output_dim=embed_dim,
+                heads=vision_heads,
+                input_resolution=image_resolution,
+                width=vision_width
+            )
+        else:
+            vision_heads = vision_width // 64
+            self.visual = VisionTransformer(
+                input_resolution=image_resolution,
+                patch_size=vision_patch_size,
+                stride_size=stride_size,
+                width=vision_width,
+                layers=vision_layers,
+                heads=vision_heads,
+                output_dim=embed_dim
+            )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask()
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if isinstance(self.visual, ModifiedResNet):
+            if self.visual.attnpool is not None:
+                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+                for name, param in resnet_block.named_parameters():
+                    if name.endswith("bn3.weight"):
+                        nn.init.zeros_(param)
+
+        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    @property
+    def dtype(self):
+        return self.visual.conv1.weight.dtype
+
+    def encode_image(self, image):
+        return self.visual(image.type(self.dtype))
+
+    def encode_text(self, text):
+        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.type(self.dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+        x = x @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_features = self.encode_image(image)
+        text_features = self.encode_text(text)
+
+        # # normalized features
+        # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        # text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+        # # cosine similarity as logits
+        # logit_scale = self.logit_scale.exp()
+        # logits_per_image = logit_scale * image_features @ text_features.t()
+        # logits_per_text = logits_per_image.t()
+
+        # # shape = [global_batch_size, global_batch_size]
+        # return logits_per_image, logits_per_text
+
+        return image_features, text_features
+
+    def load_param(self, state_dict):
+        # drop keys from the pretrained state_dict that are not present in this model's state_dict
+        param_dict = {k: v for k, v in state_dict.items() if k in self.state_dict()}
+
+        if 'model' in param_dict:
+            param_dict = param_dict['model']
+        if 'state_dict' in param_dict:
+            param_dict = param_dict['state_dict']
+        for k, v in param_dict.items():
+            if k == 'visual.positional_embedding' and v.shape != self.visual.positional_embedding.shape:
+                v = resize_pos_embed(v, self.visual.positional_embedding, self.visual.num_y, self.visual.num_x)
+            elif k == 'positional_embedding' and v.shape != self.positional_embedding.shape:
+                v = resize_text_pos_embed(v, self.context_length)
+            try:
+                self.state_dict()[k].copy_(v)
+            except:
+                print(f'===========================ERROR occur in copy {k}, {v.shape}=========================')
+                print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape))
+
+
+def resize_pos_embed(posemb, posemb_new, hight, width):
+    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    posemb = posemb.unsqueeze(0)
+    posemb_new = posemb_new.unsqueeze(0)
+
+    posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+    print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width))
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
+    posemb = torch.cat([posemb_token, posemb_grid], dim=1)
+    return posemb.squeeze(0)
+
+
+def convert_weights(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        # if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+        #     l.weight.data = l.weight.data.half()
+        #     if l.bias is not None:
+        #         l.bias.data = l.bias.data.half()
+
+        # if isinstance(l, nn.MultiheadAttention):
+        #     for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+        #         tensor = getattr(l, attr)
+        #         if tensor is not None:
+        #             tensor.data = tensor.data.half()
+
+        # for name in ["text_projection", "proj", "mcq_proj"]:
+        #     if hasattr(l, name):
+        #         attr = getattr(l, name)
+        #         if attr is not None:
+        #             attr.data = attr.data.half()
+        ...
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def build_CLIP_from_openai_pretrained(name: str, image_size: Union[int, Tuple[int, int]], stride_size: int, jit: bool = False, download_root: str = None):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+    image_size: Union[int, Tuple[int, int]]
+        Input image size; in the Re-ID task the image size is commonly set to 384x128 instead of 224x224
+
+    jit : bool
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+    download_root: str
+        path to download the model files; by default, it uses "~/.cache/clip"
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+    """
+    if name in _MODELS:
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    try:
+        # loading JIT archive
+        model = torch.jit.load(model_path, map_location="cpu")
+        state_dict = None
+    except RuntimeError:
+        # loading saved state dict
+        if jit:
+            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+            jit = False
+        state_dict = torch.load(model_path, map_location="cpu")
+
+    state_dict = state_dict or model.state_dict()
+
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+        image_resolution = vision_patch_size * grid_size
+    else:
+        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+        vision_patch_size = None
+        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+    model_cfg = {
+        'embed_dim': embed_dim,
+        'image_resolution': image_resolution,
+        'vision_layers': vision_layers,
+        'vision_width': vision_width,
+        'vision_patch_size': vision_patch_size,
+        'context_length': context_length,
+        'vocab_size': vocab_size,
+        'transformer_width': transformer_width,
+        'transformer_heads': transformer_heads,
+        'transformer_layers': transformer_layers
+    }
+
+    # modify image resolution to adapt Re-ID task
+    model_cfg['image_resolution'] = image_size
+    model_cfg['stride_size'] = stride_size
+    logger.info(f"Load pretrained {name} CLIP model with model config: {model_cfg}")
+    model = CLIP(**model_cfg)
+
+    # convert model to fp16
+    # convert_weights(model)
+
+    # resize modified pos embedding
+    model.load_param(state_dict)
+    return model, model_cfg
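The CLIP builder can also be exercised on its own; a minimal sketch, assuming the pretrained checkpoint can be downloaded into `~/.cache/clip`. A non-square `image_size` and a smaller stride trigger the positional-embedding resize in `load_param`.

```python
from lib.IRRA.model.clip_model import build_CLIP_from_openai_pretrained

clip, cfg = build_CLIP_from_openai_pretrained('ViT-B/16', image_size=(384, 128), stride_size=16)
print(cfg['embed_dim'], cfg['context_length'], cfg['vocab_size'])  # 512 77 49408 for ViT-B/16
```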
lib/IRRA/model/objectives.py ADDED
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def compute_sdm(image_fetures, text_fetures, pid, logit_scale, image_id=None, factor=0.3, epsilon=1e-8):
+    """
+    Similarity Distribution Matching
+    """
+    batch_size = image_fetures.shape[0]
+    pid = pid.reshape((batch_size, 1))  # make sure pid size is [batch_size, 1]
+    pid_dist = pid - pid.t()
+    labels = (pid_dist == 0).float()
+
+    if image_id is not None:
+        # print("Mix PID and ImageID to create soft label.")
+        image_id = image_id.reshape((-1, 1))
+        image_id_dist = image_id - image_id.t()
+        image_id_mask = (image_id_dist == 0).float()
+        labels = (labels - image_id_mask) * factor + image_id_mask
+        # labels = (labels + image_id_mask) / 2
+
+    image_norm = image_fetures / image_fetures.norm(dim=1, keepdim=True)
+    text_norm = text_fetures / text_fetures.norm(dim=1, keepdim=True)
+
+    t2i_cosine_theta = text_norm @ image_norm.t()
+    i2t_cosine_theta = t2i_cosine_theta.t()
+
+    text_proj_image = logit_scale * t2i_cosine_theta
+    image_proj_text = logit_scale * i2t_cosine_theta
+
+    # normalize the true matching distribution
+    labels_distribute = labels / labels.sum(dim=1)
+
+    i2t_pred = F.softmax(image_proj_text, dim=1)
+    i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_distribute + epsilon))
+    t2i_pred = F.softmax(text_proj_image, dim=1)
+    t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_distribute + epsilon))
+
+    loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))
+
+    return loss
+
+
+def compute_mlm(scores, labels):
+    ce = nn.CrossEntropyLoss(ignore_index=0)
+    return ce(scores, labels)
+
+
+def compute_itc(image_features, text_features, logit_scale):
+    """
+    image-text contrastive (ITC) loss, InfoNCE
+    """
+    batch_size = image_features.shape[0]
+    labels = torch.arange(start=0, end=batch_size, dtype=torch.int64)
+    labels = labels.to(image_features.device)
+
+    # normalized features
+    image_norm = image_features / image_features.norm(dim=-1, keepdim=True)
+    text_norm = text_features / text_features.norm(dim=-1, keepdim=True)
+
+    # cosine similarity as logits
+    logits_per_image = logit_scale * image_norm @ text_norm.t()
+    logits_per_text = logits_per_image.t()
+
+    loss_i = F.cross_entropy(logits_per_image, labels)
+    loss_t = F.cross_entropy(logits_per_text, labels)
+    loss = (loss_i + loss_t) / 2
+
+    return loss
+
+
+def compute_id(image_logits, text_logits, labels):
+    """
+    Instance loss proposed at http://arxiv.org/abs/1711.05535
+    """
+    criterion = nn.CrossEntropyLoss(reduction="mean")
+
+    loss = criterion(image_logits, labels) + criterion(text_logits, labels)
+
+    return loss / 2
+
+
+def compute_cmpm(image_embeddings, text_embeddings, labels, epsilon=1e-8):
+    """
+    Cross-Modal Projection Matching Loss (CMPM)
+    :param image_embeddings: Tensor with dtype torch.float32
+    :param text_embeddings: Tensor with dtype torch.float32
+    :param labels: Tensor with dtype torch.int32
+    :return:
+        i2t_loss: cmpm loss for image projected to text
+        t2i_loss: cmpm loss for text projected to image
+        pos_avg_sim: average cosine-similarity for positive pairs
+        neg_avg_sim: average cosine-similarity for negative pairs
+    """
+
+    batch_size = image_embeddings.shape[0]
+    labels_reshape = torch.reshape(labels, (batch_size, 1))
+    labels_dist = labels_reshape - labels_reshape.t()
+    labels_mask = (labels_dist == 0).float()
+
+    image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True)
+    text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)
+    image_proj_text = torch.matmul(image_embeddings, text_norm.t())
+    text_proj_image = torch.matmul(text_embeddings, image_norm.t())
+
+    # normalize the true matching distribution
+    labels_mask_norm = labels_mask / labels_mask.norm(dim=1)
+
+    i2t_pred = F.softmax(image_proj_text, dim=1)
+    i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + epsilon))
+    t2i_pred = F.softmax(text_proj_image, dim=1)
+    t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + epsilon))
+
+    cmpm_loss = torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1))
+
+    return cmpm_loss
+
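A small sketch of the objectives on random features, mainly to make the expected shapes concrete; the batch size, feature dimension, and temperature are arbitrary choices for illustration.

```python
import torch
from lib.IRRA.model import objectives

batch_size, dim = 8, 512
image_feats = torch.randn(batch_size, dim)
text_feats = torch.randn(batch_size, dim)
pids = torch.arange(batch_size)        # one identity per image-text pair in this toy batch
logit_scale = torch.tensor(1 / 0.02)   # assumed temperature of 0.02

print(objectives.compute_itc(image_feats, text_feats, logit_scale))        # scalar InfoNCE loss
print(objectives.compute_sdm(image_feats, text_feats, pids, logit_scale))  # scalar SDM loss
```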
lib/IRRA/tokenizer.py ADDED
@@ -0,0 +1,153 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+import torch
+
+@lru_cache()
+def default_bpe():
+    return "./model/bpe_simple_vocab_16e6.txt.gz"
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 bytes and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8+n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r'\s+', ' ', text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+        merges = merges[1:49152-256-2+1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v+'</w>' for v in vocab]
+        for merge in merges:
+            vocab.append(''.join(merge))
+
+        vocab.pop(-1)  # remove last one in vocab (jekyll) to keep vocab_size unchanged
+        vocab.extend(['<|mask|>', '<|startoftext|>', '<|endoftext|>'])  # vocab_size 49408
+        # vocab.extend(['<|startoftext|>', '<|endoftext|>'])  # vocab_size 49408
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|mask|>': '<|mask|>', '<|endoftext|>': '<|endoftext|>'}
+        self.pat = re.compile(r"""<\|startoftext\|>|<\|mask\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + '</w>',)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token+'</w>'
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word)-1 and word[i+1] == second:
+                    new_word.append(first+second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = ' '.join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = ''.join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+        return text
+
+def tokenize(caption: str, tokenizer, text_length=77, truncate=True) -> torch.LongTensor:
+    sot_token = tokenizer.encoder["<|startoftext|>"]
+    eot_token = tokenizer.encoder["<|endoftext|>"]
+    tokens = [sot_token] + tokenizer.encode(caption) + [eot_token]
+
+    result = torch.zeros(text_length, dtype=torch.long)
+    if len(tokens) > text_length:
+        if truncate:
+            tokens = tokens[:text_length]
+            tokens[-1] = eot_token
+        else:
+            raise RuntimeError(
+                f"Input {caption} is too long for context length {text_length}"
+            )
+    result[:len(tokens)] = torch.tensor(tokens)
+    return result  # type: ignore
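A usage sketch for the tokenizer: `tokenize` wraps the caption in start/end tokens and pads (or truncates) to the 77-token CLIP context length. It assumes the BPE vocabulary file exists at `./model/bpe_simple_vocab_16e6.txt.gz`, as `default_bpe()` expects.

```python
from lib.IRRA.tokenizer import SimpleTokenizer, tokenize

tokenizer = SimpleTokenizer()
ids = tokenize('a woman wearing a red coat and blue jeans', tokenizer, text_length=77)
print(ids.shape)  # torch.Size([77]); starts with <|startoftext|>, zero-padded at the end
```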
lib/__init__.py ADDED
File without changes
lib/components/__init__.py ADDED
File without changes
lib/utils/model.py ADDED
@@ -0,0 +1,31 @@
+import streamlit as st
+import yaml
+import torch
+
+from lib.IRRA.tokenizer import tokenize, SimpleTokenizer
+from lib.IRRA.image import prepare_images
+from lib.IRRA.model.build import build_model, IRRA
+
+from easydict import EasyDict
+
+@st.cache_resource
+def get_model():
+    args = yaml.load(open('model/configs.yaml'), Loader=yaml.FullLoader)
+    args = EasyDict(args)
+    args['training'] = False
+
+    model = build_model(args)
+
+    return model
+
+def get_similarities(text: str, images: list[str], model: IRRA) -> torch.Tensor:
+    tokenizer = SimpleTokenizer()
+
+    txt = tokenize(text, tokenizer)
+    imgs = prepare_images(images)
+
+    print(imgs.shape)
+    image_feats = model.encode_image(imgs)
+    text_feats = model.encode_text(txt.unsqueeze(0))
+
+    return text_feats @ image_feats.t()
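Putting the pieces together outside the Streamlit UI; a minimal sketch, assuming `model/configs.yaml` and the pretrained weights are available, mirroring what `app.py` does on each button press (the caption and image paths are placeholders).

```python
import torch
from lib.utils.model import get_model, get_similarities

model = get_model()  # st.cache_resource also runs outside a Streamlit session, with a warning
with torch.no_grad():
    sims = get_similarities('a man in a black jacket', ['a.jpg', 'b.jpg'], model)

best = sims.argsort(descending=True).squeeze(0)[0].item()
print(f'best match: image #{best} (score {sims.squeeze(0)[best]:.3f})')
```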