xinyu1205 commited on
Commit
6d1a894
1 Parent(s): 12fa8ea

Update models/tag2text.py

Browse files
Files changed (1) hide show
  1. models/tag2text.py +3 -0
models/tag2text.py CHANGED
@@ -25,6 +25,8 @@ import numpy as np
25
  def read_json(rpath):
26
  with open(rpath, 'r') as f:
27
  return json.load(f)
 
 
28
 
29
  class Tag2Text_Caption(nn.Module):
30
  def __init__(self,
@@ -132,6 +134,7 @@ class Tag2Text_Caption(nn.Module):
132
  targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
133
 
134
  tag = targets.cpu().numpy()
 
135
  bs = image.size(0)
136
  tag_input = []
137
  for b in range(bs):
 
25
  def read_json(rpath):
26
  with open(rpath, 'r') as f:
27
  return json.load(f)
28
+
29
+ delete_tag_index = [135]
30
 
31
  class Tag2Text_Caption(nn.Module):
32
  def __init__(self,
 
134
  targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
135
 
136
  tag = targets.cpu().numpy()
137
+ tag[:,delete_tag_index] = 0
138
  bs = image.size(0)
139
  tag_input = []
140
  for b in range(bs):