Spaces:
Runtime error
Runtime error
Update models/tag2text.py
Browse files- models/tag2text.py +14 -2
models/tag2text.py
CHANGED
@@ -26,7 +26,14 @@ def read_json(rpath):
|
|
26 |
with open(rpath, 'r') as f:
|
27 |
return json.load(f)
|
28 |
|
|
|
|
|
29 |
delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
class Tag2Text_Caption(nn.Module):
|
32 |
def __init__(self,
|
@@ -36,7 +43,7 @@ class Tag2Text_Caption(nn.Module):
|
|
36 |
vit_grad_ckpt = False,
|
37 |
vit_ckpt_layer = 0,
|
38 |
prompt = 'a picture of ',
|
39 |
-
threshold = 0.
|
40 |
):
|
41 |
"""
|
42 |
Args:
|
@@ -105,6 +112,10 @@ class Tag2Text_Caption(nn.Module):
|
|
105 |
tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
|
106 |
self.tag_array = tra_array
|
107 |
|
|
|
|
|
|
|
|
|
108 |
def del_selfattention(self):
|
109 |
del self.vision_multi.embeddings
|
110 |
for layer in self.vision_multi.encoder.layer:
|
@@ -130,7 +141,8 @@ class Tag2Text_Caption(nn.Module):
|
|
130 |
|
131 |
logits = self.fc(mlr_tagembedding[0])
|
132 |
|
133 |
-
targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
|
|
134 |
|
135 |
tag = targets.cpu().numpy()
|
136 |
tag[:,delete_tag_index] = 0
|
|
|
26 |
with open(rpath, 'r') as f:
|
27 |
return json.load(f)
|
28 |
|
29 |
+
# delete some tags that may disturb captioning
|
30 |
+
# 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
|
31 |
delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
|
32 |
+
|
33 |
+
# adjust thresholds for some tags
|
34 |
+
# default threshold: 0.68
|
35 |
+
# 2701: "person"; 2828: "man"; 1167: "woman";
|
36 |
+
tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7}
|
37 |
|
38 |
class Tag2Text_Caption(nn.Module):
|
39 |
def __init__(self,
|
|
|
43 |
vit_grad_ckpt = False,
|
44 |
vit_ckpt_layer = 0,
|
45 |
prompt = 'a picture of ',
|
46 |
+
threshold = 0.68,
|
47 |
):
|
48 |
"""
|
49 |
Args:
|
|
|
112 |
tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
|
113 |
self.tag_array = tra_array
|
114 |
|
115 |
+
self.class_threshold = torch.ones(self.num_class) * self.threshold
|
116 |
+
for key,value in tag_thrshold.items():
|
117 |
+
self.class_threshold[key] = value
|
118 |
+
|
119 |
def del_selfattention(self):
|
120 |
del self.vision_multi.embeddings
|
121 |
for layer in self.vision_multi.encoder.layer:
|
|
|
141 |
|
142 |
logits = self.fc(mlr_tagembedding[0])
|
143 |
|
144 |
+
# targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
145 |
+
targets = torch.where(torch.sigmoid(logits) > self.class_threshold.to(image.device) , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
146 |
|
147 |
tag = targets.cpu().numpy()
|
148 |
tag[:,delete_tag_index] = 0
|