hsshin98 committed
Commit dfe1f0b
Parent: ed81860
app.py CHANGED
@@ -4,7 +4,6 @@ import argparse
 import glob
 import multiprocessing as mp
 import os
-os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
 
 # fmt: off
 import sys
@@ -40,6 +39,7 @@ def setup_cfg(args):
     add_cat_seg_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
+    cfg.MODEL.DEVICE = "cpu"
     cfg.freeze()
     return cfg
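For reference, a minimal sketch of what the new cfg.MODEL.DEVICE = "cpu" override does in a detectron2-style setup. The helper name setup_cpu_cfg, the commented-out add_cat_seg_config call, and the example config path are illustrative, not taken from this Space.

# Minimal sketch, assuming detectron2 is installed. DefaultPredictor reads
# cfg.MODEL.DEVICE, so setting it to "cpu" keeps inference off the GPU.
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor


def setup_cpu_cfg(config_file, opts=()):   # hypothetical helper
    cfg = get_cfg()
    # add_cat_seg_config(cfg)              # project-specific keys, as in the diff above
    cfg.merge_from_file(config_file)
    cfg.merge_from_list(list(opts))
    cfg.MODEL.DEVICE = "cpu"               # force CPU inference
    cfg.freeze()
    return cfg

# Usage (hypothetical path):
# predictor = DefaultPredictor(setup_cpu_cfg("configs/example.yaml"))
# outputs = predictor(image_bgr)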
 
cat_seg/modeling/transformer/cat_seg_predictor.py CHANGED
@@ -58,7 +58,7 @@ class CATSegPredictor(nn.Module):
         if self.test_class_texts == None:
             self.test_class_texts = self.class_texts
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        self.device = device
         self.tokenizer = None
         if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
             # for OpenCLIP models
@@ -84,12 +84,12 @@ class CATSegPredictor(nn.Module):
             prompt_templates = ['A photo of a {} in the scene',]
         else:
             raise NotImplementedError
-
-        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
-        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         self.clip_model = clip_model.float()
         self.clip_preprocess = clip_preprocess
+
+        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
+        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         transformer = Aggregator(
             text_guidance_dim=text_guidance_dim,
@@ -161,9 +161,9 @@ class CATSegPredictor(nn.Module):
             else:
                 texts = [template.format(classname) for template in templates]  # format with class
             if self.tokenizer is not None:
-                texts = self.tokenizer(texts).cuda()
+                texts = self.tokenizer(texts).to(self.device)
             else:
-                texts = clip.tokenize(texts).cuda()
+                texts = clip.tokenize(texts).to(self.device)
             class_embeddings = clip_model.encode_text(texts)
             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
             if len(templates) != class_embeddings.shape[0]:
@@ -171,5 +171,5 @@ class CATSegPredictor(nn.Module):
             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
             class_embedding = class_embeddings
             zeroshot_weights.append(class_embedding)
-        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
+        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
         return zeroshot_weights
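The same device-handling pattern as a standalone sketch: tokenize prompts and build zero-shot text embeddings on whichever device is available instead of calling .cuda(). This assumes the OpenAI CLIP package and plain torch; the ViT-B/16 checkpoint, class names, and template are illustrative, not this Space's label set.

# Minimal sketch, assuming the OpenAI CLIP package is installed
# (pip install git+https://github.com/openai/CLIP.git).
import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)    # illustrative checkpoint

templates = ["A photo of a {} in the scene"]
classnames = ["cat", "dog"]                        # illustrative labels

with torch.no_grad():
    weights = []
    for name in classnames:
        # .to(device) instead of .cuda(), mirroring the change above
        texts = clip.tokenize([t.format(name) for t in templates]).to(device)
        emb = model.encode_text(texts)
        emb = emb / emb.norm(dim=-1, keepdim=True)
        weights.append(emb)
    # (num_templates, num_classes, embed_dim), like class_embeddings above
    zeroshot_weights = torch.stack(weights, dim=1).to(device)

Storing self.device in __init__ and reusing it inside class_embeddings, as the diff above does, keeps text-embedding construction working on CPU-only hardware without touching the rest of the pipeline.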