Zilun commited on
Commit
b0c0186
1 Parent(s): 6e32abc

Upload 8 files

Browse files
ckpt/RS5M_ViT-B-32.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:129bafaa6a097b8be52e2babf27d24f0a934dae919201e538dc698611bd1ea01
3
+ size 605222594
codebase/inference/classname_and_prompt/RSAID.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt templates for zero-shot classification on AID, in the CLIP
# prompt-ensemble style: 35 templates for each of three phrasings
# ("remote sensing image", "aerial image", "satellite image") = 105 total.

# Quality/appearance qualifiers that always take the article "a".
_QUALIFIERS = [
    'low resolution', 'bad', 'cropped', 'bright', 'dark', 'close-up',
    'black and white', 'jpeg corrupted', 'blurry', 'good',
]
# Size/judgement adjectives applied to the class-name slot itself.
_CLASS_ADJECTIVES = ['large', 'nice', 'small', 'weird', 'cool']


def _make_templates(article, noun):
    """Return the 35 prompt templates for one image-type noun phrase."""
    plain = [
        f'{article} {noun} of many {{}}.',
        f'{article} {noun} of a {{}}.',
        f'{article} {noun} of the {{}}.',
        f'{article} {noun} of the hard to see {{}}.',
        f'{article} {noun} of a hard to see {{}}.',
    ]
    qualified = []
    for adj in _QUALIFIERS:
        qualified.append(f'a {adj} {noun} of the {{}}.')
        qualified.append(f'a {adj} {noun} of a {{}}.')
    class_adjectives = []
    for adj in _CLASS_ADJECTIVES:
        class_adjectives.append(f'{article} {noun} of the {adj} {{}}.')
        class_adjectives.append(f'{article} {noun} of a {adj} {{}}.')
    return plain + qualified + class_adjectives


templates = (
    _make_templates('a', 'remote sensing image')
    + _make_templates('an', 'aerial image')
    + _make_templates('a', 'satellite image')
)
codebase/inference/classname_and_prompt/RSEuroSAT.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# EuroSAT label names used by the dataset loader (kept for reference):
# ['River', 'AnnualCrop', 'HerbaceousVegetation', 'Industrial', 'Residential',
#  'Highway', 'Pasture', 'Forest', 'SeaLake', 'PermanentCrop']
# Natural-language equivalents: river, annual crop land, brushland or
# shrubland, industrial buildings or commercial buildings, residential
# buildings or homes or apartments, highway or road, pasture land, forest,
# lake or sea, permanent crop land.

# Prompt templates for zero-shot classification on EuroSAT, in the CLIP
# prompt-ensemble style: 35 templates for each of three phrasings
# ("remote sensing image", "aerial image", "satellite image") = 105 total.

# Quality/appearance qualifiers that always take the article "a".
_QUALIFIERS = [
    'low resolution', 'bad', 'cropped', 'bright', 'dark', 'close-up',
    'black and white', 'jpeg corrupted', 'blurry', 'good',
]
# Size/judgement adjectives applied to the class-name slot itself.
_CLASS_ADJECTIVES = ['large', 'nice', 'small', 'weird', 'cool']


def _make_templates(article, noun):
    """Return the 35 prompt templates for one image-type noun phrase."""
    plain = [
        f'{article} {noun} of many {{}}.',
        f'{article} {noun} of a {{}}.',
        f'{article} {noun} of the {{}}.',
        f'{article} {noun} of the hard to see {{}}.',
        f'{article} {noun} of a hard to see {{}}.',
    ]
    qualified = []
    for adj in _QUALIFIERS:
        qualified.append(f'a {adj} {noun} of the {{}}.')
        qualified.append(f'a {adj} {noun} of a {{}}.')
    class_adjectives = []
    for adj in _CLASS_ADJECTIVES:
        class_adjectives.append(f'{article} {noun} of the {adj} {{}}.')
        class_adjectives.append(f'{article} {noun} of a {adj} {{}}.')
    return plain + qualified + class_adjectives


templates = (
    _make_templates('a', 'remote sensing image')
    + _make_templates('an', 'aerial image')
    + _make_templates('a', 'satellite image')
)
codebase/inference/classname_and_prompt/RSRESISC45.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Prompt templates for zero-shot classification on RESISC45, in the CLIP
# prompt-ensemble style: 35 templates for each of three phrasings
# ("remote sensing image", "aerial image", "satellite image") = 105 total.

# Quality/appearance qualifiers that always take the article "a".
_QUALIFIERS = [
    'low resolution', 'bad', 'cropped', 'bright', 'dark', 'close-up',
    'black and white', 'jpeg corrupted', 'blurry', 'good',
]
# Size/judgement adjectives applied to the class-name slot itself.
_CLASS_ADJECTIVES = ['large', 'nice', 'small', 'weird', 'cool']


def _make_templates(article, noun):
    """Return the 35 prompt templates for one image-type noun phrase."""
    plain = [
        f'{article} {noun} of many {{}}.',
        f'{article} {noun} of a {{}}.',
        f'{article} {noun} of the {{}}.',
        f'{article} {noun} of the hard to see {{}}.',
        f'{article} {noun} of a hard to see {{}}.',
    ]
    qualified = []
    for adj in _QUALIFIERS:
        qualified.append(f'a {adj} {noun} of the {{}}.')
        qualified.append(f'a {adj} {noun} of a {{}}.')
    class_adjectives = []
    for adj in _CLASS_ADJECTIVES:
        class_adjectives.append(f'{article} {noun} of the {adj} {{}}.')
        class_adjectives.append(f'{article} {noun} of a {adj} {{}}.')
    return plain + qualified + class_adjectives


templates = (
    _make_templates('a', 'remote sensing image')
    + _make_templates('an', 'aerial image')
    + _make_templates('a', 'satellite image')
)
codebase/inference/classname_and_prompt/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import RSEuroSAT
2
+ from . import RSAID
3
+ from . import RSRESISC45
codebase/inference/convert_weight.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open_clip
3
+ import os
4
+
5
+
6
def main(
    trained_ckpt_path="/home/zilun/RS5M_v5/ckpt/epoch_2.pt",
    model_name="ViT-H/14",
    pretrained="openclip",
    output_path="/home/zilun/RS5M_v5/ckpt/RS5M_ViT-B-32.pt",
):
    """Convert a fine-tuned RS5M training checkpoint into a plain open_clip state dict.

    Loads the training checkpoint, strips the training-time tower prefixes
    ("text_backbone." -> "", "image_backbone." -> "visual."), loads the renamed
    weights into a freshly created open_clip model, and saves the resulting
    state dict.

    Args:
        trained_ckpt_path: path to the training checkpoint (expects a
            "state_dict" entry).
        model_name: open_clip architecture name to instantiate.
        pretrained: open_clip pretrained tag used for model creation.
        output_path: where the converted state dict is written.
    """
    # NOTE(review): the default output name says ViT-B-32 while the default
    # architecture is ViT-H/14 (the ViT-B/32 pair is commented out upstream) —
    # confirm the intended model/filename pairing before publishing.
    model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)

    checkpoint = torch.load(trained_ckpt_path, map_location="cpu")["state_dict"]
    sd = dict(checkpoint)
    # Rename training-time prefixes to open_clip's naming. `elif` (instead of
    # two independent `if`s) avoids a KeyError if a key ever matched both
    # prefixes, since the first branch already removed it from the dict.
    for key in list(sd):
        if "text_backbone." in key:
            sd[key.replace("text_backbone.", "")] = sd.pop(key)
        elif "image_backbone" in key:
            sd[key.replace("image_backbone.", "visual.")] = sd.pop(key)

    # strict=False: the converted dict may lack buffers the model defines.
    msg = model.load_state_dict(sd, strict=False)
    print(msg)
    print("loaded RSCLIP")

    torch.save(model.state_dict(), output_path)


if __name__ == "__main__":
    main()
codebase/inference/inference.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import open_clip
2
+ import torch
3
+ import os
4
+ import random
5
+ import numpy as np
6
+ import argparse
7
+ from inference_tool import (zeroshot_evaluation,
8
+ retrieval_evaluation,
9
+ semantic_localization_evaluation,
10
+ get_preprocess
11
+ )
12
+
13
+
14
def random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs with the same value.

    Note: cuDNN autotuning is left enabled (benchmark=True,
    deterministic=False), so convolution algorithm selection may still be
    nondeterministic — this trades exact reproducibility for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
21
+
22
+
23
def build_model(model_name, ckpt_path, device):
    """Create an open_clip model, load RS5M weights, and build the val transform.

    Args:
        model_name: "ViT-B-32" or "ViT-H-14".
        ckpt_path: path to the converted RS5M state dict.
        device: torch device string ("cuda" / "cpu").

    Returns:
        (model, preprocess_val) — the model moved to `device` and the
        224-pixel validation preprocessing pipeline.

    Raises:
        ValueError: for an unsupported model_name. (Previously an unknown
        name fell through both branches and crashed later with a NameError.)
    """
    if model_name == "ViT-B-32":
        model, _, _ = open_clip.create_model_and_transforms("ViT-B/32", pretrained="openai")
    elif model_name == "ViT-H-14":
        model, _, _ = open_clip.create_model_and_transforms("ViT-H/14", pretrained="openclip")
    else:
        raise ValueError(f"Unsupported model name: {model_name!r} (expected 'ViT-B-32' or 'ViT-H-14')")

    # strict=False: the converted checkpoint may omit buffers the model defines.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    msg = model.load_state_dict(checkpoint, strict=False)
    print(msg)
    model = model.to(device)
    print("loaded RSCLIP")

    preprocess_val = get_preprocess(
        image_resolution=224,
    )

    return model, preprocess_val
43
+
44
+
45
def evaluate(model, preprocess, args):
    """Run the full RS5M benchmark suite and collect all metrics.

    Runs zero-shot classification (EuroSAT, RESISC45, AID), cross-modal
    retrieval (RSITMD, RSICD) and semantic localization (AIR-SLT), printing
    the accumulated metrics after each stage exactly as before.

    Args:
        model: CLIP-like model with `encode_image` / `encode_text`.
        preprocess: validation image transform.
        args: namespace with device/batch-size/dataset-dir settings.

    Returns:
        dict mapping metric names to values across all benchmarks.
    """
    print("making val dataset with transformation: ")
    print(preprocess)
    zeroshot_datasets = [
        'EuroSAT',
        'RESISC45',
        'AID'
    ]
    selo_datasets = [
        'AIR-SLT'
    ]

    model.eval()
    all_metrics = {}

    # Zero-shot classification. (The original also built per-stage `metrics`
    # dicts that were discarded; they are dropped here.)
    for zeroshot_dataset in zeroshot_datasets:
        all_metrics.update(zeroshot_evaluation(model, zeroshot_dataset, preprocess, args))
    print(all_metrics)

    # Cross-modal retrieval on RSITMD and RSICD (previously duplicated code).
    for retrieval_dataset in ["rsitmd", "rsicd"]:
        all_metrics.update(
            retrieval_evaluation(model, preprocess, args, recall_k_list=[1, 5, 10],
                                 dataset_name=retrieval_dataset)
        )
        print(all_metrics)

    # Semantic localization.
    for selo_dataset in selo_datasets:
        all_metrics.update(semantic_localization_evaluation(model, selo_dataset, preprocess, args))
    print(all_metrics)

    return all_metrics
94
+
95
+
96
def main():
    """CLI entry point: parse arguments, build the model, run all evaluations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", default="ViT-B-32", type=str,
                        help="ViT-B-32 or ViT-H-14")
    parser.add_argument("--ckpt-path", default="/home/zilun/RS5M_v5/ckpt/RS5M_ViT-B-32.pt", type=str,
                        help="Path to RS5M_ViT-B-32.pt")
    parser.add_argument("--random-seed", default=3407, type=int,
                        help="random seed")
    parser.add_argument("--test-dataset-dir", default="/home/zilun/RS5M_v5/data/rs5m_test_data", type=str,
                        help="test dataset dir")
    parser.add_argument("--batch-size", default=500, type=int,
                        help="batch size")
    parser.add_argument("--workers", default=8, type=int,
                        help="number of workers")
    args = parser.parse_args()

    # Evaluation code reads the device from args alongside the CLI options.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(args)
    # random_seed(args.random_seed)

    model, img_preprocess = build_model(args.model_name, args.ckpt_path, args.device)
    eval_result = evaluate(model, img_preprocess, args)

    for key, value in eval_result.items():
        print("{}: {}".format(key, value))


if __name__ == "__main__":
    main()
codebase/inference/inference_tool.py ADDED
@@ -0,0 +1,961 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import pdb
3
+ import tqdm
4
+ import numpy as np
5
+ import open_clip
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import os
9
+ from classname_and_prompt import *
10
+ from torchrs.datasets import AID, RESISC45, EuroSATRGB
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from PIL import Image
13
+ import pandas as pd
14
+ from clip_benchmark.datasets.builder import get_dataset_collate_fn
15
+ from clip_benchmark.metrics.zeroshot_retrieval import recall_at_k, batchify, dataloader_with_indices
16
+ from functools import reduce
17
+ import cv2
18
+ from scipy.ndimage import maximum_filter
19
+ from skimage import measure
20
+ import json
21
+ from datetime import datetime
22
+ from torchvision import transforms
23
+
24
+
25
def _convert_to_rgb(image):
    """Convert a PIL image to 3-channel RGB mode (model input expects RGB)."""
    return image.convert('RGB')
27
+
28
+
29
def get_preprocess(image_resolution=224, is_train=False, subset_name="clip", aug=None):
    """Build the torchvision transform pipeline for training or evaluation.

    Args:
        image_resolution: side length of the square model input.
        is_train: if True, return the augmented training pipeline
            (random resized crop, horizontal flip, rotation); otherwise the
            deterministic resize + center-crop validation pipeline.
        subset_name: which per-channel mean/std statistics to normalize with.
        aug: unused; kept for interface compatibility.

    Returns:
        A `transforms.Compose` pipeline producing normalized tensors.

    Raises:
        ValueError: for an unknown subset_name. (Previously `normalize` was
        left undefined and a NameError surfaced later.)
    """
    # Per-subset RGB channel statistics (mean, std).
    normalization_stats = {
        "clip": ([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]),
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        "rs5m": ([0.406, 0.423, 0.390], [0.188, 0.175, 0.185]),
        "pub11": ([0.445, 0.469, 0.441], [0.208, 0.193, 0.213]),
        "rs3": ([0.350, 0.356, 0.316], [0.158, 0.147, 0.143]),
        "geometa": ([0.320, 0.322, 0.285], [0.179, 0.168, 0.166]),
    }
    if subset_name not in normalization_stats:
        raise ValueError(f"Unknown subset_name: {subset_name!r}")
    mean, std = normalization_stats[subset_name]
    normalize = transforms.Normalize(mean=mean, std=std)

    if is_train:
        # Mild crop jitter (scale >= 0.9) plus flips/rotations — remote
        # sensing imagery has no canonical orientation.
        return transforms.Compose([
            transforms.RandomResizedCrop(
                image_resolution,
                interpolation=transforms.InterpolationMode.BICUBIC,
                scale=(0.9, 1.0),
            ),
            _convert_to_rgb,
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=(0, 360)),
            transforms.ToTensor(),
            normalize,
        ])

    return transforms.Compose([
        transforms.Resize(
            size=image_resolution,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.CenterCrop(image_resolution),
        _convert_to_rgb,
        transforms.ToTensor(),
        normalize,
    ])
86
+
87
+
88
def zeroshot_get_dataset(dataset_name, root, split, transform=None):
    """Load a zero-shot classification dataset and attach RS5M prompt templates.

    Args:
        dataset_name: one of "EuroSAT", "AID", "RESISC45".
        root: parent directory containing (or to contain) the dataset folder.
        split: unused by these torchrs datasets; kept for interface compatibility.
        transform: image transform applied by the dataset.

    Returns:
        The dataset with `templates` set and `classes` rewritten as
        lowercase natural-language labels.

    Raises:
        ValueError: for an unknown dataset_name. (Previously `dataset` was
        never assigned and a NameError surfaced below.)
    """
    if dataset_name == "EuroSAT":
        dataset_root = os.path.join(root, "eurosat-rgb")
        os.makedirs(dataset_root, exist_ok=True)
        dataset = EuroSATRGB(
            root=dataset_root,
            transform=transform
        )
        dataset.templates = RSEuroSAT.templates

    elif dataset_name == "AID":
        dataset_root = os.path.join(root, "AID")
        os.makedirs(dataset_root, exist_ok=True)
        dataset = AID(
            root=dataset_root,
            transform=transform
        )
        dataset.templates = RSAID.templates

    elif dataset_name == "RESISC45":
        dataset_root = os.path.join(root, "RESISC45")
        os.makedirs(dataset_root, exist_ok=True)
        dataset = RESISC45(
            root=dataset_root,
            transform=transform
        )
        dataset.templates = RSRESISC45.templates

    else:
        raise ValueError(f"Unsupported zero-shot dataset: {dataset_name!r}")

    # Turn directory-style class names ("dense_residential", "lake/sea") into
    # lowercase natural-language labels for prompt filling. The original also
    # contained no-op `dataset.classes = dataset.classes` lines, removed here.
    dataset.classes = [c.replace('_', ' ').replace('/', ' ').lower() for c in dataset.classes]

    return dataset
125
+
126
+
127
def zeroshot_classifier(model, classnames, templates, args):
    """Build a zero-shot classifier weight matrix from class names and prompts.

    For each class name, every template's '{}' slot is filled with the name,
    the prompts are tokenized and text-encoded, and the per-prompt embeddings
    are averaged into one normalized class embedding (prompt ensembling).

    Args:
        model: CLIP-like model exposing `encode_text`.
        classnames: list of class-name strings.
        templates: list of prompt strings containing a '{}' placeholder.
        args: namespace providing `device`.

    Returns:
        CPU tensor with one column per class (embed_dim x n_classes).
    """
    tokenizer = open_clip.tokenize
    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            # Fill every prompt template with this class name.
            texts = [template.replace('{}', classname) for template in templates]
            context_length = 77  # CLIP's fixed token context length
            texts = tokenizer(texts, context_length=context_length).to(args.device)

            class_embeddings = model.encode_text(texts)
            # Average the prompt ensemble into a single class vector.
            class_embeddings = class_embeddings.mean(dim=0)
            class_embedding = F.normalize(class_embeddings, dim=-1)
            # NOTE(review): this second normalization is redundant — the
            # vector is already unit-length after F.normalize above.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding.cpu())
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1)
    return zeroshot_weights
143
+
144
+
145
def zeroshot_evaluation(model, zeroshot_dataset, preprocess, args):
    """Zero-shot classify one dataset and return its accuracy metric.

    Args:
        model: CLIP-like model.
        zeroshot_dataset: dataset name understood by `zeroshot_get_dataset`.
        preprocess: image transform for the dataset.
        args: namespace with `test_dataset_dir`, `batch_size`, `workers`, `device`.

    Returns:
        {'<dataset>-zeroshot-acc': accuracy_percent} with a builtin float value.
    """
    import copy  # hoisted from mid-function; local to avoid touching file imports

    dataset = zeroshot_get_dataset(dataset_name=zeroshot_dataset, split='test', root=args.test_dataset_dir, transform=preprocess)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.workers)

    logging.info(f'Calculating classifier for {zeroshot_dataset}')
    # Deep-copy so prompt filling can never mutate the dataset's class list.
    classnames = copy.deepcopy(dataset.classes)
    classifier = zeroshot_classifier(model, classnames, dataset.templates, args)

    logging.info(f'Calculating image features for {zeroshot_dataset}')
    acc, features, labels = zeroshot_run(model, classifier, dataloader, args)
    logging.info(f'{zeroshot_dataset} zero-shot accuracy: {acc}%')

    # Cast to builtin float so the result is JSON-serializable.
    results = {f'{zeroshot_dataset}-zeroshot-acc': float(acc)}
    return results
166
+
167
+
168
def zeroshot_accuracy(output, target, topk=(1,)):
    """Return top-1 accuracy (%) of logits `output` against labels `target`.

    Note: although `topk` controls how many predictions are gathered, only
    the top-1 row is actually scored, matching the original behavior.
    """
    max_k = max(topk)
    top_predictions = output.topk(max_k, 1, True, True)[1].t()
    hits = top_predictions.eq(target.view(1, -1).expand_as(top_predictions))
    n_correct = hits[0].reshape(-1).float().sum(0, keepdim=True)
    return float(n_correct.cpu().numpy()) * 100 / len(target)
173
+
174
+
175
def zeroshot_run(model, classifier, dataloader, args):
    """Encode every image, score against the classifier, and compute accuracy.

    Returns:
        (accuracy rounded to 2 decimals, all normalized image features on CPU,
        all ground-truth labels).
    """
    feature_batches, label_batches, logit_batches = [], [], []
    with torch.no_grad():
        for images, target in tqdm.tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            # L2-normalized features so the dot product below is cosine similarity.
            feats = F.normalize(model.encode_image(images), dim=-1).detach().cpu()
            feature_batches.append(feats)
            label_batches.append(target)
            # Standard CLIP logit scaling of 100.
            logit_batches.append(100. * feats @ classifier)

        all_image_features = torch.cat(feature_batches)
        all_labels = torch.cat(label_batches)
        all_logits = torch.cat(logit_batches)

        acc = zeroshot_accuracy(all_logits, all_labels, topk=(1,))
    return round(acc, 2), all_image_features, all_labels
195
+
196
+
197
class CsvDataset(Dataset):
    """Image–text retrieval dataset backed by a CSV/TSV annotation file.

    Reads image paths (column `img_key`) and captions (column `caption_key`)
    from `input_filename`, then groups captions per unique image so that
    one-image-to-multiple-text retrieval works.
    """

    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", nori_dataset=False,
                 images_dir=''):
        logging.debug(f'Loading csv data from {input_filename}.')
        # The RSICD annotation file is not UTF-8; gb18030 decodes it.
        if 'rsicd' in input_filename:
            df = pd.read_csv(input_filename, sep=sep, encoding='gb18030')
        else:
            df = pd.read_csv(input_filename, sep=sep)

        self.nori_dataset = nori_dataset
        self.f = None  # placeholder file handle; not used in this code path
        self.images_dir = images_dir

        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()

        self.transforms = transforms

        self.duplicate()

        logging.debug('Done loading data.')

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # After duplicate(), captions[index] may be a list of all captions
        # for one image rather than a single string.
        texts = self.captions[index]
        image = Image.open(os.path.join(self.images_dir, str(self.images[index])))
        image = self.transforms(image)

        return image, texts

    def duplicate(self):
        """Collapse repeated image rows into one entry holding all its captions."""
        unique_images, indexs = np.unique(self.images, return_index=True)
        # NOTE(review): the dedup block is assumed to run only when duplicates
        # exist (inside this `if`) — confirm against the original indentation.
        if len(unique_images) != len(self.images):
            logging.debug(
                f'Amoung all {len(self.images)} images, there are only {len(unique_images)} unique images. Dupication will be performed to enable one-image-to-multiple-text retrieval.')
            self.duplicated_images = []
            self.duplicated_captions = []
            for index in indexs:
                self.duplicated_images.append(self.images[index])
                # All rows sharing this image path contribute their captions.
                same_indexs = [i for i, x in enumerate(self.images) if x == self.images[index]]
                captions = []
                for same_index in same_indexs:
                    captions.append(self.captions[same_index])
                self.duplicated_captions.append(captions)

            self.images = self.duplicated_images
            self.captions = self.duplicated_captions
246
+
247
+
248
def retrieval_evaluation(model, preprocess, args, recall_k_list=[1, 5, 10], dataset_name=None):
    """
    Evaluate cross-modal (image<->text) retrieval recall on RSITMD or RSICD.

    Modified from https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/metrics/zeroshot_retrieval.py

    Parameters
    ----------
    model: torch.nn.Module
        CLIP-like model with `encode_image` and `encode_text`
    preprocess:
        image transform applied by the CSV dataset
    args:
        namespace providing `test_dataset_dir`, `batch_size`, `workers`, `device`
    recall_k_list: list of int
        recall@k k's to use
    dataset_name: str
        "rsitmd" or "rsicd"; any other value leaves `dataset` unset and
        raises NameError below

    Returns
    -------
    dict of retrieval metrics (percentages rounded to 2 decimals)
    """

    if dataset_name == "rsitmd":
        dataset = CsvDataset(
            input_filename=os.path.join(args.test_dataset_dir, "rsitmd", "rsitmd_test.csv"),
            transforms=preprocess,
            img_key="filename",
            caption_key="title",
            sep=",",
            images_dir=os.path.join(args.test_dataset_dir, "rsitmd", "images")
        )
    elif dataset_name == "rsicd":
        dataset = CsvDataset(
            input_filename=os.path.join(args.test_dataset_dir, "rsicd", "rsicd_test.csv"),
            transforms=preprocess,
            img_key="filename",
            caption_key="title",
            sep=",",
            images_dir=os.path.join(args.test_dataset_dir, "rsicd", "RSICD_images")
        )

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        collate_fn=get_dataset_collate_fn('mscoco_captions')
    )
    n_batches = len(dataloader)
    tokenizer = open_clip.tokenize
    # list of batches of image embeddings
    batch_images_emb_list = []
    # list of batches of text embeddings
    batch_texts_emb_list = []
    # for each text, the index of its image (each image can have multiple texts)
    texts_image_index = []
    dataloader = dataloader_with_indices(dataloader)

    for batch_images, batch_texts, inds in tqdm.tqdm(dataloader, total=n_batches):
        batch_images = batch_images.to(args.device)
        # store the index of the image for each of its texts
        batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
        # flatten and tokenize all texts in the batch
        batch_texts = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(args.device)

        # compute L2-normalised embeddings of images and texts
        with torch.no_grad():
            batch_image_features = model.encode_image(batch_images)
            batch_text_features = model.encode_text(batch_texts)
            batch_images_emb = F.normalize(batch_image_features, dim=-1)
            batch_texts_emb = F.normalize(batch_text_features, dim=-1)

        batch_images_emb_list.append(batch_images_emb.cpu())
        batch_texts_emb_list.append(batch_texts_emb.cpu())
        texts_image_index.extend(batch_texts_image_index)

    # size of the last-seen batch; used to chunk the score matrix in batchify
    batch_size = len(batch_images_emb_list[0])

    # concatenate all embeddings
    images_emb = torch.cat(batch_images_emb_list)
    texts_emb = torch.cat(batch_texts_emb_list)

    # similarity score for each (text, image) pair
    scores = texts_emb @ images_emb.t()

    # boolean matrix marking which (text, image) pairs are ground-truth matches
    positive_pairs = torch.zeros_like(scores, dtype=bool)
    positive_pairs[torch.arange(len(scores)), texts_image_index] = True
    metrics = {}
    for recall_k in recall_k_list:
        '''
        Note that recall_at_k computes **actual** recall i.e. nb_true_positive/nb_positives, where the number
        of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts matching that image among the top-k.
        Also, the number of positives are the total number of texts matching the image in the dataset, as we have a set of captions
        for each image, that number will be greater than 1 for text retrieval.
        However, image/text retrieval recall@k, the way it is done in CLIP-like papers, is a bit different.
        recall@k, in CLIP-like papers, is, for each image, either 1 or 0. It is 1 if at least one text matches the image among the top-k.
        So we can easily compute that using the actual recall, by checking whether there is at least one true positive,
        which would be the case if the recall is greater than 0. Once we compute the recall for each image (or text), we average
        it over the dataset.
        '''
        metrics[f"retrieval-image2text-R@{recall_k}-{dataset_name}"] = (batchify(recall_at_k, scores.T,
                                                                                 positive_pairs.T, batch_size,
                                                                                 args.device,
                                                                                 k=recall_k) > 0).float().mean().item() * 100

    for recall_k in recall_k_list:
        metrics[f"retrieval-text2image-R@{recall_k}-{dataset_name}"] = (batchify(recall_at_k, scores, positive_pairs,
                                                                                 batch_size, args.device,
                                                                                 k=recall_k) > 0).float().mean().item() * 100

    # mean recall over all image2text and text2image R@k values computed above
    metrics[f"retrieval-mean-recall-{dataset_name}"] = np.mean(list(metrics.values()))

    for key, item in metrics.items():
        metrics[key] = round(float(item), 2)
    logging.info(f'{dataset_name} retrieval recall: {metrics}%')

    return metrics
370
+
371
+
372
class SLM(object):
    """Semantic Localization Metrics (SLM) calculator.

    Scores a predicted probability map against annotated polygon regions
    with four indicators (see ``_format_output_dict`` for the exact keys):

    * Rsu (higher is better): share of probability mass falling inside the
      annotated regions relative to outside (``rsu``).
    * Rda (higher is better): how concentrated the attention centres are
      within each annotated region (``rda``).
    * Ras (lower is better): normalised distance between attention centres
      and annotation centres (``ras``).
    * Rmi: weighted combination of the three above (``rmi``).
    """

    # **
    # * Copyright @2022 AI, AIRCAS. (mails.ucas.ac.cn)
    #
    # @author yuanzhiqiang <[email protected]>
    # 2022/03/08

    def __init__(self):
        # logging
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger()

        # parameters
        self.rsu_beta = 0.707  # sharpness of the exponential in rsu()
        self.rsu_eps = 1e-7  # guards the division in rsu()

        self.ras_expand_factor = 1.5  # enlarges the annotation radius when matching centres
        self.ras_filter_times = 5  # number of box-filter smoothing passes on the prob map
        self.ras_scala_beta = 3  # exponential scaling of the offset in ras()

        self.rda_eta = 0.5  # decay rate penalising many attention centres in rda()

        # rmi() mixing weights for rsu / (1 - ras) / rda
        self.rmi_wsu = 0.4
        self.rmi_was = 0.35
        self.rmi_wda = 0.25

        # visual settings
        self.visual_ras = False
        self.src_addmap_path = None

        # sum indicator: running per-sample metrics, averaged by get_the_mean_metric()
        self.all_metrics = self._format_output_dict()

    def _format_output_dict(self, *params):
        """
        Format the metrics output dict.
        :param params: either empty (initialises each metric to an empty list)
            or exactly four values in the order rsu, rda, ras, rmi
        :return: dict keyed by metric name (arrow indicates the better direction)
        """
        len_params = len(params)
        if len_params == 0: init_param = [[] for i in range(4)]
        elif len_params == 4: init_param = params
        else: raise NotImplementedError

        return {
            "↑ Rsu [0 ~ 1]": init_param[0],
            "↑ Rda [0 ~ 1]": init_param[1],
            "↓ Ras [0 ~ 1]": init_param[2],
            "↑ Rmi [0 ~ 1]": init_param[3]
        }

    def logging_acc(self, metrics_dict, prob_path = None, ave = False):
        """
        Log the metrics.
        :param metrics_dict: dict of metric name -> value
        :param prob_path: path of the evaluated probability map (ignored when ave=True)
        :param ave: True when logging dataset-level averages
        :return: 0
        """

        if not ave:
            self.logger.info("Eval {}".format(prob_path))
        else:
            self.logger.info("+++++++++++++++Average++++++++++++++")

        self.logger.info("+++++++ Calc the SLM METRICS +++++++")
        for metric, value in metrics_dict.items():
            self.logger.info("++++ {}:{:.4f} ++++".format(metric, value))
        self.logger.info("++++++++++++++++++++++++++++++++++++\n")

    def set_visual_options(self, visual_ras, src_addmap_path):
        """
        Enable/disable ras() visualisation output.
        :param visual_ras: flag enabling the debug drawing in ras()
        :param src_addmap_path: path of the source add-map image to draw on
        """
        self.visual_ras = visual_ras
        self.src_addmap_path = src_addmap_path
        return True

    def read_gray_to_prob(self, probmap_path):
        """
        Read a grayscale probability map and rescale it to [0, 1].
        :param probmap_path: probability-map image path
        :return: (H, W) float array of probabilities
        """
        gray_image = cv2.imread(probmap_path, cv2.IMREAD_GRAYSCALE)
        prob = gray_image / 255.0
        return prob

    def generate_mask_by_points(self, prob, points_list):
        """
        Rasterise the annotated polygon regions into a binary mask.
        :param prob: probability map (only its shape is used)
        :param points_list: list of polygons, each a list of (x, y) points
        :return: (H, W) mask with 1 inside the polygons
        """
        H, W = prob.shape

        mask = np.zeros((H, W))
        points_list = [np.array(i, np.int32) for i in points_list]
        # fill
        cv2.fillPoly(mask, points_list, 1)
        return mask

    def _get_region_center_radius(self, region_point):
        """
        Compute a region's centre and an (expanded) radius.
        :param region_point: list of (x, y) polygon points
        :return: mid_x, mid_y, radius (mean point-to-centre distance scaled
            by ras_expand_factor)
        """
        mid_x = int(reduce(lambda x, y: x+y, np.array(region_point)[:, 0]) / len(region_point))
        mid_y = int(reduce(lambda x, y: x+y, np.array(region_point)[:, 1]) / len(region_point))
        radius = int(np.mean([np.linalg.norm(np.array(point) - np.array([mid_x, mid_y])) for point in region_point]) * self.ras_expand_factor)
        return mid_x, mid_y, radius

    def _get_prob_center_in_gray(self, prob):
        """
        Extract local-maximum attention centres from the probability map.
        :param prob: probability map in [0, 1]
        :return: (centers, continuous_area) — centres as [x, y] sorted by
            descending probability (only those with prob >= 0.5), and the
            labelled high-probability (>150/255) connected-component image
        """

        # recover the 8-bit image from the probability map
        gray_img = np.asarray(prob * 255.0, dtype=np.uint8)
        # construct connected high-probability areas (threshold 150/255)
        continuous_area = np.asarray(gray_img > 150, np.uint8) * 255
        continuous_area = np.uint8(measure.label(continuous_area, connectivity=2))

        # soften with repeated box filtering
        for i in range(self.ras_filter_times):
            gray_img = cv2.boxFilter(gray_img, ddepth=-1, ksize=(50, 50))

        # keep only local maxima (binary map)
        mx = maximum_filter(gray_img, size=1000)
        gray_img = np.where(mx == gray_img, gray_img, 0)
        gray_img = np.asarray(gray_img > 0, np.uint8) * 255

        # get connected-component centroids of the maxima
        labels = measure.label(gray_img, connectivity=2)
        all_region_infos = measure.regionprops(labels)
        centers = [[int(i) for i in prop.centroid][::-1] for prop in all_region_infos]

        # construct (x, y, prob) list and sort by probability, best first
        v_center = [[c[0], c[1], prob[c[1]][c[0]]] for c in centers]
        v_center.sort(key= lambda x: x[2], reverse=True)
        centers = list(map(lambda x: x[:2], v_center))

        # filter out low-confidence centres
        centers = [i for i in centers if prob[i[1]][i[0]] >= 0.5]

        return centers, continuous_area

    def _get_offset_between_real_and_synthetic(self, real_center_radius, prob_centers, bina_img):
        """
        Mean distance from each annotation centre to the predicted centres
        falling inside its (expanded) radius.
        :param real_center_radius: list of (x, y, r) annotation centres
        :param prob_centers: predicted attention centres
        :param bina_img: unused here; kept for interface compatibility
        :return: list of offsets, one per annotation (r when no centre is inside)
        """

        # no predicted centre at all: fall back to the first region's radius
        if len(prob_centers) == 0 : return [real_center_radius[0][2]]

        offsets = []
        for center_radius in real_center_radius:
            x, y, r = center_radius

            # calc the l2 distance to every predicted centre
            dises = list(map(lambda p: np.linalg.norm(np.array([x, y] - np.array(p))), prob_centers))

            # keep only distances within the region's circle
            dises = list(filter(lambda d: d <= r, dises))

            # if no centre falls inside, use the radius as the offset
            offsets.append(np.mean(dises) if len(dises) != 0 else r)

        return offsets

    def _trans_ras_offset_to_scalable_ras(self, offsets, centers_and_radius):
        """
        Convert raw distance offsets to the final Ras value in [0, 1].
        :param offsets: per-region offsets
        :param centers_and_radius: per-region (x, y, r); r normalises each offset
        :return: scaled Ras value
        """

        # granular transformation: offset as a fraction of each region's radius
        granular_offet = np.mean([off/v[2] for off, v in zip(offsets, centers_and_radius)])

        # scala transformation: exponential rescaling into [0, 1]
        granular_offet = (np.exp(self.ras_scala_beta * granular_offet) - 1) / (np.exp(self.ras_scala_beta) - 1)

        return granular_offet

    def ras(self, region_lists, prob, visual=True, src_img=None):
        """
        Compute the Ras metric: attention centres should be close to
        annotation centres (lower is better).
        :param region_lists: annotated polygon regions
        :param prob: probability map
        :param visual: draw debug circles when True and src_img is given
        :param src_img: path of the image to draw the debug output on
        :return: ras value
        """

        # get the annotation centres and radii
        centers_and_radius = [self._get_region_center_radius(i) for i in region_lists]

        # get the highest-probability centres from the probability map
        prob_centers, bina_img = self._get_prob_center_in_gray(prob)

        # per-region offset between annotation centre and predicted centres
        offsets = self._get_offset_between_real_and_synthetic(centers_and_radius, prob_centers, bina_img)

        # convert distance offsets to the ras value
        ras = self._trans_ras_offset_to_scalable_ras(offsets, centers_and_radius)

        # optional debug visualisation written to ./img_circle.jpg
        if visual and (src_img != None):
            src_img = cv2.imread(src_img)

            # annotated regions
            for c_r in centers_and_radius:
                cv2.circle(src_img, (c_r[0], c_r[1]), c_r[2], 2, 3)

            # candidate attention centres, numbered by rank
            for idx, point in enumerate(prob_centers):
                cv2.circle(src_img, tuple(point), 6*(idx+1), 1, 4)
                cv2.putText(src_img, str(idx+1), tuple(point), cv2.FONT_HERSHEY_COMPLEX, 6, (0, 0, 0), 25)

            cv2.imwrite("./img_circle.jpg", src_img)

        return ras

    def rsu(self, prob, mask):
        """
        Compute the Rsu metric: salient-area proportion (higher is better).
        :param prob: probability map
        :param mask: binary mask of the annotated regions
        :return: rsu value
        """

        all_mask_value = np.sum(np.multiply(prob, mask))
        all_value = np.sum(prob)
        H, W = np.shape(mask)
        all_mask = np.sum(mask)

        # in-region vs out-of-region probability mass
        left_frac = all_mask_value / (all_value - all_mask_value + self.rsu_eps)

        # compensates for the region's share of the image area
        right_frac = (H * W - all_mask) / all_mask

        rsu = -np.exp(-1 * self.rsu_beta * left_frac * right_frac) + 1

        return rsu

    def rda(self, region_lists, prob):
        """
        Compute the Rda metric: attention should focus on one point per
        region (higher is better).
        :param region_lists: annotated polygon regions
        :param prob: probability map
        :return: rda value (mean over regions)
        """

        # get the annotation centres and radii
        centers_and_radius = [self._get_region_center_radius(i) for i in region_lists]

        # get the highest-probability centres from the probability map
        prob_centers, bina_img = self._get_prob_center_in_gray(prob)

        rda = []
        for c_r in centers_and_radius:
            x, y, r = c_r

            # predicted centres falling inside this region's circle
            backup_points = list(filter(lambda p: np.linalg.norm(np.array([x, y] - np.array(p))) <= r, prob_centers))

            # 0 or 1 candidate: score is exactly that count
            len_backup_points = len(backup_points)
            if len_backup_points <= 1 :
                rda.append(float(len_backup_points))
                continue

            # >= 2 candidates: penalise their spatial spread and their count
            centers_attention = np.average(backup_points, axis=0)
            dises = list(map(lambda p: np.linalg.norm(np.array(centers_attention - np.array(p))), backup_points))
            meas_dis = np.mean(dises) / r

            rda_single = 0.5 * (1 - meas_dis) + np.exp(- self.rda_eta * (len_backup_points + 2))

            rda.append(rda_single)

        return np.mean(rda)

    def rmi(self, rsu, rda, ras):
        """
        Weighted mean indicator of the three metrics (ras is inverted
        because lower ras is better).
        :param rsu: rsu value
        :param rda: rda value
        :param ras: ras value
        :return: rmi value
        """
        return self.rmi_wsu * rsu + self.rmi_was * (1 - ras) + self.rmi_wda * rda

    def evaluate(self, prob_path, region_list):
        """
        Evaluate one probability map against its annotated regions.
        :param prob_path: probability-map image path
        :param region_list: annotated polygon regions
        :return: dict of the four SLM metrics for this sample
        """
        # read prob
        prob = self.read_gray_to_prob(prob_path)

        # generate mask
        mask = self.generate_mask_by_points(prob, region_list)

        # rsu
        rsu = self.rsu(prob, mask)

        # ras
        ras = self.ras(region_list, prob, visual=self.visual_ras, src_img=self.src_addmap_path)

        # rda
        rda = self.rda(region_list, prob)

        # combined indicator
        rmi = self.rmi(rsu, rda, ras)

        # package metrics
        metrics = self._format_output_dict(rsu, rda, ras, rmi)

        return metrics

    def append_metric(self, metric):
        """
        Accumulate one sample's metrics for the final average.
        :param metric: dict returned by evaluate()
        """
        for k in metric.keys():
            self.all_metrics[k].append(metric[k])

    def get_the_mean_metric(self):
        """
        Average all accumulated metrics and log them.
        :return: dict of mean metric values
        """
        mean_metric = {}
        for k in self.all_metrics:
            mean_metric[k] = np.mean(self.all_metrics[k])

        self.logging_acc(mean_metric, ave=True)
        return mean_metric
735
+
736
+
737
def semantic_localization_evaluation(model, selo_dataset, preprocess, args):
    """Evaluate semantic localization (SeLo) on the AIR-SLT benchmark.

    For every (image, caption) sample the image is split into overlapping
    sub-images (done by the dataset), each sub-image is scored against the
    caption via normalised CLIP features, and the scores are stitched back
    into a pixel-level probability map. The map is written to disk and
    evaluated against the annotated regions with the SLM metrics
    (Rsu/Rda/Ras/Rmi), which are averaged over the dataset.

    :param model: CLIP-like model with ``encode_image`` / ``encode_text``
    :param selo_dataset: must be 'AIR-SLT' (only supported benchmark)
    :param preprocess: transform applied to every sub-image by the dataset
    :param args: namespace providing ``test_dataset_dir`` and ``device``
    :return: dict of averaged SLM metrics (plain floats)
    """
    assert selo_dataset == 'AIR-SLT'

    def collect_fn_selo(batch):
        # Batch size is fixed to 1; unwrap the single sample.
        assert len(batch) == 1
        source_img, subimages, text, points, subimg_name_list = batch[0]
        return source_img, subimages, text, points, subimg_name_list

    dataset = get_selo_dataset(
        root=args.test_dataset_dir, transform=preprocess, identifier=None
    )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=collect_fn_selo
    )
    tokenizer = open_clip.tokenize
    logger = dataset.logger
    slm_metric = SLM()

    with torch.no_grad():
        for idx, sample in tqdm.tqdm(enumerate(dataloader)):
            source_img, subimages, text, points, subimg_name_list = sample
            subimages = subimages.to(args.device)
            text = tokenizer(text).to(args.device)
            text_features = model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            # Cosine similarity of each sub-image against the caption.
            sim_results = []
            for subimage in subimages:
                subimage = subimage.unsqueeze(0)
                sub_img_feat = model.encode_image(subimage)
                sub_img_feat /= sub_img_feat.norm(dim=-1, keepdim=True)
                similarity = (sub_img_feat * text_features).sum().detach().cpu().numpy()
                sim_results.append(similarity)

            img_row = np.shape(source_img)[0]
            img_col = np.shape(source_img)[1]

            # Accumulate similarity per pixel; heat_num counts how many
            # sub-images cover each pixel so overlaps can be averaged.
            heat_map = np.zeros([img_row, img_col], dtype=float)
            heat_num = np.zeros([img_row, img_col], dtype=float)
            # BUG FIX: this inner loop previously reused `idx`, clobbering the
            # sample index used for the output filenames below — every sample's
            # maps were written under the same (sub-image-count) index.
            for sub_idx, subimg_name in enumerate(subimg_name_list):
                # Sub-image filenames encode their crop window:
                # "<row_start>_<row_end>_<col_start>_<col_end>.jpg".
                r_start, r_end, c_start, c_end = subimg_name.replace(".jpg", "").split("_")
                heat_map[int(r_start):int(r_end), int(c_start):int(c_end)] += sim_results[sub_idx]
                heat_num[int(r_start):int(r_end), int(c_start):int(c_end)] += 1

            # Average overlapping contributions (vectorized; identical result
            # to the original per-pixel double loop).
            heat_map = heat_map / heat_num

            # Normalise to [0, 255], smooth, and colour-map for visualisation.
            adaptive = np.asarray(heat_map)
            adaptive = adaptive - np.min(adaptive)
            probmap = adaptive / np.max(adaptive)
            # must convert to type uint8 for OpenCV
            probmap = np.uint8(255 * probmap)
            probmap = cv2.medianBlur(probmap, 251)
            heatmap = cv2.applyColorMap(probmap, cv2.COLORMAP_JET)
            img_add = cv2.addWeighted(source_img, 0.7, heatmap, 0.3, 0)

            probmap_path = os.path.join(dataset.cache_path, "probmap_{}.jpg".format(idx))
            heatmap_path = os.path.join(dataset.cache_path, "heatmap_{}.jpg".format(idx))
            addmap_path = os.path.join(dataset.cache_path, "addmap_{}.jpg".format(idx))

            cv2.imwrite(probmap_path, probmap)
            cv2.imwrite(heatmap_path, heatmap)
            cv2.imwrite(addmap_path, img_add)

            metrics = slm_metric.evaluate(probmap_path, region_list=points)
            slm_metric.append_metric(metrics)

    mean_metric = slm_metric.get_the_mean_metric()

    results = {}
    logging.info(f'{selo_dataset} selo metrics: {mean_metric}')

    for key, item in mean_metric.items():
        results[key] = float(item)

    return results
828
+
829
+
830
class AIR_SLT(Dataset):
    # Ref: https://github.com/xiaoyuan1996/SemanticLocalizationMetrics/blob/master/predict/generate_selo.py
    """AIR-SLT semantic-localization benchmark dataset.

    Each sample yields the full source image, a stack of overlapping
    sub-image crops (written to a cache directory on every access), the
    caption, the annotated polygon regions, and the sub-image filenames
    (which encode each crop's position within the source image).
    """

    def __init__(self, root, subimage_transform, identifier):
        """
        :param root: AIR-SLT root containing ``annotations/anno.json`` and ``imgs``
        :param subimage_transform: transform applied to every cropped sub-image
        :param identifier: unused; kept from the timestamped cache-path variant below
        """
        super().__init__()
        self.json_path = os.path.join(root, "annotations", "anno.json")
        # self.cache_path = os.path.join(root, "selo_cache_{}_{}".format(identifier, str(datetime.now()).replace(" ", "-").replace(":", "-").replace(".", "-")))
        self.cache_path = os.path.join(root, "selo_cache")
        os.makedirs(self.cache_path, exist_ok=True)
        with open(self.json_path, 'r', encoding='utf8') as fp:
            self.json_data = json.load(fp)
        self.img_root = os.path.join(root, "imgs")
        self.subimage_transform = subimage_transform
        self.logger = get_logger(os.path.join(self.cache_path, 'log.txt'))
        # Crop scales in pixels, underscore-separated; parsed in __getitem__.
        self.step = "256_512_768"

    def __len__(self):
        return len(self.json_data)

    def __getitem__(self, index):
        """Return (source image array, stacked sub-image tensors, [caption], regions, sub-image names)."""
        item = self.json_data[index]
        img_name = item['jpg_name']
        text = item['caption']
        points = item['points']
        steps = [int(step) for step in self.step.split("_")]
        img_path = os.path.join(self.img_root, img_name)

        # Split the source image into cached sub-images at every scale, then load them.
        self.split_image(img_path, steps)
        with torch.no_grad():
            subimages_dir = os.path.join(self.cache_path, os.path.basename(img_path).split(".")[0]) + '_subimages'
            subimages = os.listdir(subimages_dir)

            img = cv2.imread(img_path)
            subimg_list = []
            subimg_name_list = []
            for subimage_name in subimages:
                subimage_path = os.path.join(subimages_dir, subimage_name)
                subimg = Image.open(subimage_path)
                subimg = self.subimage_transform(subimg).unsqueeze(0)
                subimg_list.append(subimg)
                subimg_name_list.append(subimage_name)
            subimgs = torch.vstack(subimg_list)
            return img, subimgs, [text], points, subimg_name_list

    def split_image(self, img_path, steps):
        """Crop *img_path* into overlapping square tiles and cache them on disk.

        For each step size, tiles are produced at full-step and half-step
        offsets; each tile is resized to 256x256 and saved under a name
        encoding its window: "<w_start>_<w_end>_<h_start>_<h_end>.jpg".
        The cache directory is emptied before writing.
        """
        subimage_files_dir = os.path.join(self.cache_path, os.path.basename(img_path).split(".")[0])

        # Directory for the cropped sub-images (cleared on every call).
        subimages_dir = subimage_files_dir + '_subimages'
        if os.path.exists(subimages_dir):
            delete_dire(subimages_dir)
        else:
            os.makedirs(subimages_dir)

        # Read Image
        source_img = cv2.imread(img_path)
        # NOTE(review): img_weight is the row count and img_height the column
        # count — the names look swapped, but they are used consistently below
        # (w indexes rows, h indexes columns).
        img_weight = np.shape(source_img)[0]
        img_height = np.shape(source_img)[1]

        for step in steps:
            # gap == step gives non-overlapping tiles; gap == step/2 shifts
            # the grid by half a tile for overlap.
            for gap in [step, 0.5 * step]:
                gap = int(gap)

                # Cut img
                for h in range(0 + (step - gap), img_height, step):
                    h_start, h_end = h, h + step
                    # clamp the last tile to the image border
                    if h_end >= img_height:
                        h_start, h_end = img_height - step, img_height

                    for w in range(0 + (step - gap), img_weight, step):
                        w_start, w_end = w, w + step
                        # clamp the last tile to the image border
                        if w_end >= img_weight:
                            w_start, w_end = img_weight - step, img_weight

                        cut_img_name = str(w_start) + "_" + str(w_end) + "_" + str(h_start) + "_" + str(h_end) + ".jpg"
                        cut_img = source_img[w_start:w_end, h_start:h_end]
                        cut_img = cv2.resize(cut_img, (256, 256), interpolation=cv2.INTER_CUBIC)

                        cv2.imwrite(os.path.join(subimages_dir, cut_img_name), cut_img)
919
+
920
+
921
def delete_dire(dire):
    """Remove every file and subdirectory under *dire*, keeping *dire* itself.

    :param dire: directory whose contents are to be deleted
    """
    dir_list = []
    for root, dirs, files in os.walk(dire):
        for afile in files:
            os.remove(os.path.join(root, afile))
        for adir in dirs:
            dir_list.append(os.path.join(root, adir))
    # BUG FIX: os.walk collects directories top-down, and the original removed
    # them in that order — os.rmdir(parent) fails while (now-empty) child
    # directories still exist. Remove deepest directories first instead.
    for bdir in reversed(dir_list):
        os.rmdir(bdir)
930
+
931
+
932
+ # logger
933
def get_logger(save_path=None):
    """Configure and return the root logger used by the SeLo evaluation.

    Attaches a console (stream) handler — at most once per process — and,
    when *save_path* is given, a UTF-8 file handler for that path.

    FIX: the original unconditionally added a new StreamHandler on every
    call, so repeated calls made each log line print multiple times.

    :param save_path: optional path of the log file
    :return: the configured root logger
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # console/file print level
    formatter = logging.Formatter('%(asctime)s %(message)s')

    # Console output: add only if no plain StreamHandler is attached yet
    # (type check is exact because FileHandler subclasses StreamHandler).
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)

    # Optional log file.
    if save_path is not None:
        fh = logging.FileHandler(save_path, encoding='utf8')
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
950
+
951
+
952
def get_selo_dataset(root, transform, identifier):
    """Instantiate the AIR-SLT semantic-localization dataset.

    :param root: test-dataset root; the benchmark lives under root/AIR-SLT
    :param transform: transform applied to every cropped sub-image
    :param identifier: forwarded to AIR_SLT (currently unused there)
    :return: an AIR_SLT dataset instance
    """
    return AIR_SLT(
        root=os.path.join(root, "AIR-SLT"),
        subimage_transform=transform,
        identifier=identifier
    )