hsshin98 committed
Commit • dfe1f0b
1 Parent(s): ed81860
cpu
app.py
CHANGED

@@ -4,7 +4,6 @@ import argparse
 import glob
 import multiprocessing as mp
 import os
-os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
 
 # fmt: off
 import sys
@@ -40,6 +39,7 @@ def setup_cfg(args):
     add_cat_seg_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
+    cfg.MODEL.DEVICE = "cpu"
     cfg.freeze()
     return cfg
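
For reference, a minimal sketch of how the new cfg.MODEL.DEVICE override takes effect downstream (a sketch only: it assumes detectron2 is installed and reuses the setup_cfg patched above; the config path and the use of DefaultPredictor are illustrative, not the app's actual wiring):

    # Hypothetical usage of the patched setup_cfg; the config path is a placeholder.
    from argparse import Namespace
    from detectron2.engine import DefaultPredictor

    args = Namespace(config_file="configs/demo.yaml", opts=[])  # placeholder path
    cfg = setup_cfg(args)  # cfg.MODEL.DEVICE is now forced to "cpu"

    # detectron2's build_model() moves the network to torch.device(cfg.MODEL.DEVICE),
    # so the predictor runs entirely on CPU and no .cuda() call is needed downstream.
    predictor = DefaultPredictor(cfg)
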
cat_seg/modeling/transformer/cat_seg_predictor.py
CHANGED

@@ -58,7 +58,7 @@ class CATSegPredictor(nn.Module):
         if self.test_class_texts == None:
             self.test_class_texts = self.class_texts
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        self.device = device
         self.tokenizer = None
         if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
             # for OpenCLIP models
@@ -84,12 +84,12 @@ class CATSegPredictor(nn.Module):
             prompt_templates = ['A photo of a {} in the scene',]
         else:
             raise NotImplementedError
-
-        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
-        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         self.clip_model = clip_model.float()
         self.clip_preprocess = clip_preprocess
+
+        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
+        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         transformer = Aggregator(
             text_guidance_dim=text_guidance_dim,
@@ -161,9 +161,9 @@ class CATSegPredictor(nn.Module):
             else:
                 texts = [template.format(classname) for template in templates]  # format with class
                 if self.tokenizer is not None:
-                    texts = self.tokenizer(texts).cuda()
+                    texts = self.tokenizer(texts).to(self.device)
                 else:
-                    texts = clip.tokenize(texts).cuda()
+                    texts = clip.tokenize(texts).to(self.device)
                 class_embeddings = clip_model.encode_text(texts)
                 class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
                 if len(templates) != class_embeddings.shape[0]:
@@ -171,5 +171,5 @@ class CATSegPredictor(nn.Module):
                     class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
                     class_embedding = class_embeddings
                 zeroshot_weights.append(class_embedding)
-        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
+        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
         return zeroshot_weights
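
The predictor changes follow the usual device-agnostic pattern: resolve the device once, keep it on the module as self.device, and move tensors with .to(self.device) instead of hard-coding .cuda(). A minimal self-contained sketch of that pattern (an illustrative stand-in, not the CAT-Seg class itself):

    import torch
    import torch.nn as nn

    class TextEncoderStub(nn.Module):
        """Illustrative stand-in for a CLIP-style text encoder."""

        def __init__(self, vocab_size=49408, dim=512):
            super().__init__()
            # Resolve the device once, exactly as the commit does, and store it.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.embed = nn.Embedding(vocab_size, dim).to(self.device)

        @torch.no_grad()
        def encode(self, token_ids):
            # .to(self.device) instead of .cuda(), so the same code also runs
            # on CPU-only machines.
            token_ids = token_ids.to(self.device)
            feats = self.embed(token_ids).mean(dim=1)
            return feats / feats.norm(dim=-1, keepdim=True)

    enc = TextEncoderStub()
    dummy = torch.randint(0, 49408, (4, 77))  # 4 prompts, CLIP-style context length
    print(enc.encode(dummy).shape)  # torch.Size([4, 512])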