yuxi-liu-wired commited on
Commit
b65a332
1 Parent(s): ae78120
CSD/CSD/loss_utils.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms as transforms
2
+ import torchvision.transforms.functional as F
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ np.random.seed(0)
7
+
8
+
9
+ class GaussianBlur(object):
10
+ """blur a single image on CPU"""
11
+ def __init__(self, kernel_size):
12
+ radias = kernel_size // 2
13
+ kernel_size = radias * 2 + 1
14
+ self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
15
+ stride=1, padding=0, bias=False, groups=3)
16
+ self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
17
+ stride=1, padding=0, bias=False, groups=3)
18
+ self.k = kernel_size
19
+ self.r = radias
20
+
21
+ self.blur = nn.Sequential(
22
+ nn.ReflectionPad2d(radias),
23
+ self.blur_h,
24
+ self.blur_v
25
+ )
26
+
27
+ self.pil_to_tensor = transforms.ToTensor()
28
+ self.tensor_to_pil = transforms.ToPILImage()
29
+
30
+ def __call__(self, img):
31
+ img = self.pil_to_tensor(img).unsqueeze(0)
32
+
33
+ sigma = np.random.uniform(0.1, 2.0)
34
+ x = np.arange(-self.r, self.r + 1)
35
+ x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
36
+ x = x / x.sum()
37
+ x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
38
+
39
+ self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
40
+ self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
41
+
42
+ with torch.no_grad():
43
+ img = self.blur(img)
44
+ img = img.squeeze()
45
+
46
+ img = self.tensor_to_pil(img)
47
+
48
+ return
49
+
50
+
51
+ s=1
52
+ size = 224
53
+
54
+ normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
55
+
56
+ transforms_branch0 = transforms.Compose([
57
+ transforms.Resize(size=size, interpolation=F.InterpolationMode.BICUBIC),
58
+ transforms.CenterCrop(size),
59
+ transforms.ToTensor(),
60
+ normalize,
61
+ ])
62
+
63
+ transforms_branch1 = transforms.Compose([
64
+ transforms.RandomResizedCrop(size, interpolation=F.InterpolationMode.BICUBIC),
65
+ transforms.RandomHorizontalFlip(),
66
+ transforms.RandomVerticalFlip(p=0.3),
67
+ transforms.RandomRotation(degrees=np.random.choice([0,90,180,270])),
68
+ transforms.ToTensor(),
69
+ normalize,
70
+ ])
71
+
72
+ color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
73
+ transforms_branch2 = transforms.Compose([
74
+ # transforms.RandomResizedCrop(size=size, interpolation=F.InterpolationMode.BICUBIC),
75
+ transforms.Resize(size=size, interpolation=F.InterpolationMode.BICUBIC),
76
+ transforms.CenterCrop(size),
77
+ transforms.RandomHorizontalFlip(),
78
+ transforms.RandomApply([transforms.ColorJitter(brightness=0.5, contrast=0.5,
79
+ saturation=0.5,hue=0.1)
80
+ ], p=0.6),
81
+ transforms.RandomApply([transforms.RandomInvert(),transforms.RandomGrayscale(), transforms.GaussianBlur(kernel_size=(5,5), sigma=(0.1, 4))], p=0.8),
82
+ # GaussianBlur(kernel_size=int(0.1 * size)),
83
+ transforms.ToTensor(),
84
+ normalize
85
+ ])
86
+
87
+
88
+ class ContrastiveTransformations(object):
89
+
90
+ def __init__(self, transforms_b0, transforms_b1,transforms_b2):
91
+ self.transforms_b0 = transforms_b0
92
+ self.transforms_b1 = transforms_b1
93
+ self.transforms_b2 = transforms_b2
94
+
95
+ def __call__(self, x):
96
+ return [self.transforms_b0(x), self.transforms_b1(x), self.transforms_b2(x)]
CSD/CSD/losses.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Author: Yonglong Tian ([email protected])
3
+ Date: May 07, 2020
4
+ Code from https://github.com/HobbitLong/SupContrast/blob/master/losses.py
5
+ """
6
+ from __future__ import print_function
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ class SupConLoss(nn.Module):
13
+ """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
14
+ It also supports the unsupervised contrastive loss in SimCLR"""
15
+ def __init__(self, temperature=0.07, contrast_mode='all',
16
+ base_temperature=1.0):
17
+ super(SupConLoss, self).__init__()
18
+ self.temperature = temperature
19
+ self.contrast_mode = contrast_mode
20
+ self.base_temperature = base_temperature
21
+
22
+ def forward(self, features, labels=None, mask=None):
23
+ """Compute loss for model. If both `labels` and `mask` are None,
24
+ it degenerates to SimCLR unsupervised loss:
25
+ https://arxiv.org/pdf/2002.05709.pdf
26
+
27
+ Args:
28
+ features: hidden vector of shape [bsz, n_views, ...].
29
+ labels: ground truth of shape [bsz].
30
+ mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
31
+ has the same class as sample i. Can be asymmetric.
32
+ Returns:
33
+ A loss scalar.
34
+ """
35
+ device = (torch.device('cuda')
36
+ if features.is_cuda
37
+ else torch.device('cpu'))
38
+
39
+ if len(features.shape) < 3:
40
+ raise ValueError('`features` needs to be [bsz, n_views, ...],'
41
+ 'at least 3 dimensions are required')
42
+ if len(features.shape) > 3:
43
+ features = features.view(features.shape[0], features.shape[1], -1)
44
+
45
+ batch_size = features.shape[0]
46
+ if labels is not None and mask is not None:
47
+ raise ValueError('Cannot define both `labels` and `mask`')
48
+ elif labels is None and mask is None:
49
+ mask = torch.eye(batch_size, dtype=torch.float32).to(device)
50
+ elif labels is not None:
51
+ labels = labels.contiguous().view(-1, 1)
52
+ if labels.shape[0] != batch_size:
53
+ raise ValueError('Num of labels does not match num of features')
54
+ mask = torch.eq(labels, labels.T).float().to(device)
55
+ else:
56
+ mask = mask.float().to(device)
57
+
58
+ contrast_count = features.shape[1]
59
+ contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
60
+ if self.contrast_mode == 'one':
61
+ anchor_feature = features[:, 0]
62
+ anchor_count = 1
63
+ elif self.contrast_mode == 'all':
64
+ anchor_feature = contrast_feature
65
+ anchor_count = contrast_count
66
+ else:
67
+ raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
68
+
69
+ anchor_dot_contrast = torch.div(
70
+ torch.matmul(anchor_feature, contrast_feature.T),
71
+ self.temperature)
72
+ # for numerical stability
73
+ logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
74
+ logits = anchor_dot_contrast - logits_max.detach()
75
+
76
+ # tile mask
77
+ mask = mask.repeat(anchor_count, contrast_count)
78
+ # mask-out self-contrast cases
79
+ logits_mask = torch.scatter(
80
+ torch.ones_like(mask),
81
+ 1,
82
+ torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
83
+ 0
84
+ )
85
+ mask = mask * logits_mask
86
+
87
+ # compute log_prob
88
+ exp_logits = torch.exp(logits) * logits_mask
89
+ log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) # NOTE: modified based on https://github.com/HobbitLong/SupContrast/issues/104
90
+
91
+ # compute mean of log-likelihood over positive, adding a small value in case mask row is 0
92
+ mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6)
93
+
94
+ # loss
95
+ loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
96
+ loss = loss.view(anchor_count, batch_size).mean()
97
+
98
+ return loss
99
+
CSD/CSD/model.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import clip
4
+ import copy
5
+ from torch.autograd import Function
6
+
7
+
8
+ from .utils import convert_weights_float
9
+
10
+
11
+ class ReverseLayerF(Function):
12
+
13
+ @staticmethod
14
+ def forward(ctx, x, alpha):
15
+ ctx.alpha = alpha
16
+
17
+ return x.view_as(x)
18
+
19
+ @staticmethod
20
+ def backward(ctx, grad_output):
21
+ output = grad_output.neg() * ctx.alpha
22
+
23
+ return output, None
24
+
25
+
26
+ ## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
27
+ class ProjectionHead(nn.Module):
28
+ def __init__(
29
+ self,
30
+ embedding_dim,
31
+ projection_dim,
32
+ dropout=0
33
+ ):
34
+ super().__init__()
35
+ self.projection = nn.Linear(embedding_dim, projection_dim)
36
+ self.gelu = nn.GELU()
37
+ self.fc = nn.Linear(projection_dim, projection_dim)
38
+ self.dropout = nn.Dropout(dropout)
39
+ self.layer_norm = nn.LayerNorm(projection_dim)
40
+
41
+ def forward(self, x):
42
+ projected = self.projection(x)
43
+ x = self.gelu(projected)
44
+ x = self.fc(x)
45
+ x = self.dropout(x)
46
+ x = x + projected
47
+ x = self.layer_norm(x)
48
+ return x
49
+
50
+
51
+ def init_weights(m): # TODO: do we need init for layernorm?
52
+ if isinstance(m, nn.Linear):
53
+ torch.nn.init.xavier_uniform_(m.weight)
54
+ if m.bias is not None:
55
+ nn.init.normal_(m.bias, std=1e-6)
56
+
57
+
58
+ class CSD_CLIP(nn.Module):
59
+ """backbone + projection head"""
60
+ def __init__(self, name='vit_large',content_proj_head='default'):
61
+ super(CSD_CLIP, self).__init__()
62
+ self.content_proj_head = content_proj_head
63
+ if name == 'vit_large':
64
+ clipmodel, _ = clip.load("ViT-L/14")
65
+ self.backbone = clipmodel.visual
66
+ self.embedding_dim = 1024
67
+ elif name == 'vit_base':
68
+ clipmodel, _ = clip.load("ViT-B/16")
69
+ self.backbone = clipmodel.visual
70
+ self.embedding_dim = 768
71
+ self.feat_dim = 512
72
+ else:
73
+ raise Exception('This model is not implemented')
74
+
75
+ convert_weights_float(self.backbone)
76
+ self.last_layer_style = copy.deepcopy(self.backbone.proj)
77
+ if content_proj_head == 'custom':
78
+ self.last_layer_content = ProjectionHead(self.embedding_dim,self.feat_dim)
79
+ self.last_layer_content.apply(init_weights)
80
+
81
+ else:
82
+ self.last_layer_content = copy.deepcopy(self.backbone.proj)
83
+
84
+ self.backbone.proj = None
85
+
86
+ @property
87
+ def dtype(self):
88
+ return self.backbone.conv1.weight.dtype
89
+
90
+ def forward(self, input_data, alpha=None):
91
+
92
+ feature = self.backbone(input_data)
93
+
94
+ if alpha is not None:
95
+ reverse_feature = ReverseLayerF.apply(feature, alpha)
96
+ else:
97
+ reverse_feature = feature
98
+
99
+ style_output = feature @ self.last_layer_style
100
+ style_output = nn.functional.normalize(style_output, dim=1, p=2)
101
+
102
+ # if alpha is not None:
103
+ if self.content_proj_head == 'custom':
104
+ content_output = self.last_layer_content(reverse_feature)
105
+ else:
106
+ content_output = reverse_feature @ self.last_layer_content
107
+ content_output = nn.functional.normalize(content_output, dim=1, p=2)
108
+ return feature, content_output, style_output
CSD/CSD/train_csd.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import argparse
4
+ import json
5
+ import math
6
+ import os
7
+ import pathlib
8
+ import sys
9
+ import time
10
+ import datetime
11
+ import numpy as np
12
+ import copy
13
+ import torch
14
+ import torch.backends.cudnn as cudnn
15
+ import torch.nn as nn
16
+ import torch.nn.parallel
17
+ import torch.optim
18
+ import torch.utils.data
19
+ import torch.utils.data.distributed
20
+ from pathlib import Path
21
+
22
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
23
+
24
+ from CSD import utils
25
+ from data.wikiart import WikiArtTrain
26
+ from data.laion import LAION, LAIONDedup
27
+ from CSD.loss_utils import ContrastiveTransformations, transforms_branch0, transforms_branch1, transforms_branch2
28
+ from CSD.model import CSD_CLIP
29
+ from CSD.losses import SupConLoss
30
+
31
+
32
+ def get_args_parser():
33
+
34
+ parser = argparse.ArgumentParser('CSD', add_help=False)
35
+
36
+ # Model
37
+ parser.add_argument("-a","--arch",default='vit_base', type=str)
38
+
39
+ # Data
40
+ parser.add_argument('--train_set', default='wikiart', # 'wikiart' or 'laion'
41
+ help='Wiki art data path')
42
+ parser.add_argument('--train_path', required=True,
43
+ help='Wiki art data path')
44
+ parser.add_argument('--train_anno_path',
45
+ default='-projects/diffusion_rep/data/laion_style_subset',
46
+ help='Annotation dir, used only for LAION')
47
+ parser.add_argument("--min_images_per_label", default=1, type=int,
48
+ help="minimum images for a label (used only for laion)")
49
+ parser.add_argument("--max_images_per_label", default=100000, type=int,
50
+ help="minimum images for a label (used only for laion)")
51
+
52
+ parser.add_argument('--eval_set', default='wikiart', # 'domainnet' or 'wikiart'
53
+ help='Wiki art data path')
54
+ parser.add_argument('--eval_path',required=True,
55
+ help='Path to query dataset.')
56
+ parser.add_argument("--maxsize", default=512, type=int,
57
+ help="maximum size of the val dataset to be used")
58
+
59
+ # Optimization
60
+ parser.add_argument( "--use_fp16", action="store_true",
61
+ help="use fp16")
62
+ parser.add_argument( "--use_distributed_loss", action="store_true",
63
+ help="use distributed loss")
64
+ parser.add_argument('--clip_grad', type=float, default=3.0,
65
+ help="""Maximal parameter gradient norm if using
66
+ gradient clipping. Clipping with norm .3 ~ 1.0 can
67
+ help optimization for larger ViT architectures.
68
+ 0 for disabling.""")
69
+ parser.add_argument("--iters", default=100000, type=int, # default: eval only
70
+ help="number of total iterations to run")
71
+ parser.add_argument("-b", "--batch_size_per_gpu", default=64, type=int,
72
+ help="batch size per GPU (default: 64)")
73
+ parser.add_argument("--lr", "--learning_rate", default=0.003, type=float,
74
+ help="learning rate", dest="lr",)
75
+ parser.add_argument("--lr_bb", "--learning_rate_bb", default=0.0001, type=float,
76
+ help="learning rat for backbone", dest="lr_bb",)
77
+ parser.add_argument("--wd", "--weight_decay", default=1e-4, type=float,
78
+ help="weight decay (default: 1e-4)", dest="weight_decay")
79
+ parser.add_argument("--warmup_iters", default=30000, type=int,
80
+ help="Number of iterations for the linear learning-rate warm up.")
81
+ parser.add_argument('--min_lr', type=float, default=1e-6, help="""Target LR at the
82
+ end of optimization. We use a cosine LR schedule with linear warmup.""")
83
+ parser.add_argument('--lr_scheduler_type', type=str, default='constant_with_warmup')
84
+ parser.add_argument('--freeze_last_layer', default=0, type=int,
85
+ help="""Number of iterations during which we keep the
86
+ output layer fixed. Typically doing so during
87
+ first few iters helps training. Try increasing this
88
+ value if the loss does not decrease.""")
89
+
90
+ parser.add_argument('--content_proj_head', type=str, default='default')
91
+ parser.add_argument('--lambda_s', default=1, type=float, help='Weighting on style loss')
92
+ parser.add_argument('--lambda_c', default=0, type=float, help='Weighting on content loss')
93
+ parser.add_argument('--lam_sup', default=5, type=float, help='Supervised style loss lambda')
94
+ parser.add_argument('--temp', default=0.1, type=float, help='contrastive temperature')
95
+
96
+ parser.add_argument('--clamp_content_loss', default=None, type=float, help='Clipping the content loss')
97
+ parser.add_argument( "--non_adv_train", action="store_true",
98
+ help="dont train content adversarially, use neg of content loss")
99
+ parser.add_argument('--eval_embed', type=str, default='head', help='which embeddings to use in evaluation')
100
+ parser.add_argument('--style_loss_type', type=str, default='SupCon', help='which loss function for style loss computation')
101
+ # Logging Params
102
+ parser.add_argument('--output_dir', required=True, type=str, help='Path to save logs and checkpoints.')
103
+ parser.add_argument('--print_freq', default=100, type=int, help='Print the logs every x iterations.')
104
+ parser.add_argument('--saveckp_freq', default=5000, type=int, help='Save checkpoint every x iterations.')
105
+ parser.add_argument('--eval_freq', default=5000, type=int, help='Eval the model every x iterations.')
106
+ parser.add_argument('--eval_k', type=int, nargs='+', default=[1, 5, 100], help='eval map and recall at these k values.')
107
+
108
+ # Misc
109
+ parser.add_argument("--resume_if_available", action="store_true")
110
+ parser.add_argument("--seed", default=42, type=int,
111
+ help="seed for initializing training. ")
112
+ parser.add_argument("-j", "--workers", default=4, type=int,
113
+ help="number of data loading workers (default: 32)")
114
+ parser.add_argument("--rank", default=-1, type=int,
115
+ help="node rank for distributed training")
116
+ parser.add_argument("--dist_url", default="env://",
117
+ help="url used to set up distributed training")
118
+ parser.add_argument("--local_rank", default=0, type=int,
119
+ help="Please ignore and do not set this argument.")
120
+ return parser
121
+
122
+
123
+ def sample_infinite_data(loader, seed=0):
124
+ rng = torch.Generator()
125
+ rng.manual_seed(seed)
126
+ BIG_NUMBER = 9999999999999
127
+ while True:
128
+ # Randomize dataloader indices before every epoch:
129
+ try: # Only relevant for distributed sampler:
130
+ shuffle_seed = torch.randint(0, BIG_NUMBER, (1,), generator=rng).item()
131
+ loader.sampler.set_epoch(shuffle_seed)
132
+ except AttributeError:
133
+ pass
134
+ for batch in loader:
135
+ yield batch
136
+
137
+
138
+ def main():
139
+ parser = argparse.ArgumentParser('CSD', parents=[get_args_parser()])
140
+ args = parser.parse_args()
141
+
142
+ if args.non_adv_train:
143
+ assert args.clamp_content_loss is not None, 'You have to clamp content loss in non-adv style of training'
144
+ utils.init_distributed_mode(args)
145
+ if args.seed is not None:
146
+ utils.fix_random_seeds(args.seed)
147
+
148
+ print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
149
+ cudnn.benchmark = True
150
+
151
+ # ======================= setup logging =======================
152
+ if utils.is_main_process() and args.iters > 0:
153
+ os.makedirs(args.output_dir, exist_ok=True)
154
+
155
+ # ======================= preparing data =======================
156
+ if args.lambda_c < 1e-3:
157
+ train_transforms = ContrastiveTransformations(transforms_branch1, transforms_branch1, transforms_branch2)
158
+ else:
159
+ train_transforms = ContrastiveTransformations(transforms_branch0, transforms_branch1, transforms_branch2)
160
+
161
+ if args.train_set == 'wikiart':
162
+ train_dataset = WikiArtTrain(
163
+ args.train_path, 'database',
164
+ transform=train_transforms)
165
+ elif args.train_set == 'laion':
166
+ train_dataset = LAION(
167
+ args.train_path, args.train_anno_path,
168
+ min_images_per_label=args.min_images_per_label,
169
+ max_images_per_label=args.max_images_per_label,
170
+ transform=train_transforms)
171
+ elif args.train_set == 'laion_dedup':
172
+ train_dataset = LAIONDedup(
173
+ args.train_path, args.train_anno_path,
174
+ transform=train_transforms)
175
+ else:
176
+ raise NotImplementedError
177
+
178
+ train_sampler = torch.utils.data.DistributedSampler(train_dataset, shuffle=True)
179
+ train_loader = torch.utils.data.DataLoader(
180
+ train_dataset, batch_size=args.batch_size_per_gpu, drop_last=True,
181
+ num_workers=args.workers, pin_memory=True, sampler=train_sampler)
182
+ train_loader = sample_infinite_data(train_loader, args.seed)
183
+
184
+ if args.eval_set == 'wikiart':
185
+ vq_dataset = WikiArtTrain(
186
+ args.eval_path, 'query', transform=transforms_branch0, maxsize=args.maxsize)
187
+ vidx_dataset = WikiArtTrain(
188
+ args.eval_path, 'database', transform=transforms_branch0, maxsize=8*args.maxsize)
189
+
190
+ vq_loader = torch.utils.data.DataLoader(
191
+ vq_dataset, batch_size=2*args.batch_size_per_gpu, drop_last=True,
192
+ num_workers=min(args.workers, 2), pin_memory=True, shuffle=False)
193
+ vidx_loader = torch.utils.data.DataLoader(
194
+ vidx_dataset, batch_size=2*args.batch_size_per_gpu, drop_last=True,
195
+ num_workers=min(args.workers, 2), pin_memory=True, shuffle=False)
196
+ print(f"Data loaded: there are {len(train_dataset)} train images.")
197
+ print(f"Data loaded: there are {len(vq_dataset)} query and {len(vidx_dataset)} index images.")
198
+
199
+ # ======================= building model =======================
200
+ model = CSD_CLIP(args.arch, args.content_proj_head) # TODO: projection dim into hyperparam
201
+ model = model.cuda()
202
+ # synchronize batch norms (if any)
203
+ if utils.has_batchnorms(model):
204
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
205
+
206
+ if args.distributed:
207
+ model = nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
208
+ model_without_ddp = model.module
209
+ else:
210
+ model_without_ddp = model
211
+
212
+ print(f"Model built with {args.arch} network.")
213
+
214
+ # ======================= setup loss and optimizers =======================
215
+ loss_content = SupConLoss(temperature=args.temp) # TODO: Do we want 2 diff
216
+ loss_style = SupConLoss(temperature=args.temp)
217
+
218
+ params_groups = utils.get_params_groups(model_without_ddp.backbone)
219
+ # lr is set by scheduler
220
+ opt_bb = torch.optim.SGD(
221
+ params_groups, lr=0, momentum=0.9, weight_decay=args.weight_decay)
222
+
223
+ if args.content_proj_head != 'default':
224
+ opt_proj = torch.optim.SGD(
225
+ [{'params': model_without_ddp.last_layer_style},
226
+ {'params': model_without_ddp.last_layer_content.parameters()},],
227
+ # [model_without_ddp.last_layer_style, *model_without_ddp.last_layer_content.parameters()],
228
+ lr=0, momentum=0.9, weight_decay=0, # we do not apply weight decay
229
+ )
230
+ else:
231
+ opt_proj = torch.optim.SGD(
232
+ [model_without_ddp.last_layer_style, model_without_ddp.last_layer_content],
233
+ lr=0, momentum=0.9, weight_decay=0, # we do not apply weight decay
234
+ )
235
+
236
+ fp16_scaler = None
237
+ if args.use_fp16:
238
+ fp16_scaler = torch.cuda.amp.GradScaler()
239
+
240
+ # ======================= init schedulers =======================
241
+ if args.lr_scheduler_type =='cosine':
242
+ lr_schedule_bb = utils.cosine_scheduler(
243
+ args.lr_bb * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
244
+ min(args.min_lr, args.lr_bb),
245
+ max(args.iters, 1), warmup_iters=min(args.warmup_iters, args.iters)
246
+ )
247
+
248
+ lr_schedule_proj = utils.cosine_scheduler(
249
+ args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
250
+ min(args.min_lr, args.lr),
251
+ max(args.iters, 1), warmup_iters=min(args.warmup_iters, args.iters)
252
+ )
253
+ elif args.lr_scheduler_type =='constant_with_warmup':
254
+ lr_schedule_bb = utils.constant_with_warmup_scheduler(
255
+ args.lr_bb * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
256
+ max(args.iters, 1), warmup_iters=min(args.warmup_iters, args.iters),
257
+ )
258
+
259
+ lr_schedule_proj = utils.constant_with_warmup_scheduler(
260
+ args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
261
+ max(args.iters, 1), warmup_iters=min(args.warmup_iters, args.iters),
262
+ )
263
+ else:
264
+ print('Using constant LR for training')
265
+ lr_schedule_bb = utils.constant_with_warmup_scheduler(
266
+ args.lr_bb * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
267
+ max(args.iters, 1), warmup_iters=0,
268
+ )
269
+
270
+ lr_schedule_proj = utils.constant_with_warmup_scheduler(
271
+ args.lr * (args.batch_size_per_gpu * utils.get_world_size()) / 256., # linear scaling rule
272
+ max(args.iters, 1), warmup_iters=0,
273
+ )
274
+
275
+ print(f"Loss, optimizer and schedulers ready.")
276
+
277
+ # ======================= optionally resume training =======================
278
+ to_restore = {"iter": 0}
279
+ if args.resume_if_available:
280
+ if not args.output_dir.endswith(".pth"):
281
+ ckpt_path = os.path.join(args.output_dir, "checkpoint.pth")
282
+ else:
283
+ ckpt_path = args.output_dir
284
+ utils.restart_from_checkpoint(
285
+ ckpt_path,
286
+ run_variables=to_restore,
287
+ model_state_dict=model,
288
+ opt_bb=opt_bb,
289
+ opt_proj=opt_proj,
290
+ fp16_scaler=fp16_scaler,
291
+ )
292
+ print(f"Start iter: {to_restore['iter']}")
293
+ start_iter = to_restore["iter"]
294
+ save_dict = None
295
+ print("Running eval before training!")
296
+ val_stats = evaluate(model, vq_loader, vidx_loader, fp16_scaler is not None, args.eval_k, args.eval_embed)
297
+ if start_iter >= args.iters:
298
+ print(f"Start iter {start_iter} >= Max iters {args.iters} training!")
299
+ return
300
+
301
+ start_time = time.time()
302
+ print("Starting CSD training !")
303
+ metric_logger = utils.MetricLogger(delimiter=" ", max_len=args.iters)
304
+ header = 'Iter:'
305
+
306
+ #TODO: Check if we need to set model to train mode
307
+ model.eval()
308
+ for iter, batch in enumerate(metric_logger.log_every(train_loader, 100, header)):
309
+ # ======================= training =======================
310
+
311
+ if iter < start_iter:
312
+ continue
313
+
314
+ if iter >= args.iters:
315
+ break
316
+
317
+ # update learning rates according to their schedule
318
+ # it = len(train_loader) * epoch + it # global training iteration
319
+ p = float(iter) / args.iters
320
+
321
+ for param_group in opt_bb.param_groups:
322
+ param_group["lr"] = lr_schedule_bb[iter]
323
+
324
+ for param_group in opt_proj.param_groups:
325
+ param_group["lr"] = lr_schedule_proj[iter]
326
+ if args.non_adv_train:
327
+ alpha = None
328
+ else:
329
+ alpha = 2. / (1. + np.exp(-10 * p)) - 1
330
+ images, artists, *_ = batch
331
+ if args.lambda_c < 1e-3:
332
+ images = torch.cat([images[0],images[1]], dim=0)
333
+ else:
334
+ images = torch.cat(images, dim=0)
335
+
336
+ # import torchvision
337
+ # torchvision.utils.save_image(images,'./temp.png')
338
+ images= images.cuda(non_blocking=True)
339
+ artists = artists.cuda(non_blocking=True).float()
340
+
341
+ with torch.cuda.amp.autocast(fp16_scaler is not None):
342
+ _ , content_output, style_output = model(images, alpha)
343
+
344
+ # Normalize the output features for each image
345
+ content_output = nn.functional.normalize(content_output, dim=1, p=2)
346
+ style_output = nn.functional.normalize(style_output, dim=1, p=2)
347
+
348
+ # Split the output features for each image and its views
349
+ style_output = utils.split_reshape(style_output, args.batch_size_per_gpu, [0, 1])
350
+ content_output = utils.split_reshape(content_output, args.batch_size_per_gpu, [0, -1])
351
+
352
+ # Gather tensors from all GPUs
353
+ if args.use_distributed_loss:
354
+ style_output = torch.cat(utils.GatherLayer.apply(style_output), dim=0)
355
+ content_output = torch.cat(utils.GatherLayer.apply(content_output), dim=0)
356
+
357
+ # Compute content loss (SimCLR loss, doesn't use labels)
358
+ loss_c = loss_content(content_output)
359
+ if args.clamp_content_loss is not None:
360
+ loss_c = loss_c.clamp(max = args.clamp_content_loss)
361
+ if args.non_adv_train:
362
+ loss_c = -1 * loss_c
363
+
364
+ # Compute style loss
365
+ if args.use_distributed_loss:
366
+ artists = torch.cat(utils.GatherLayer.apply(artists), dim=0)
367
+
368
+ label_mask = artists @ artists.t()
369
+ if args.style_loss_type == 'SimClr':
370
+ loss_s_ssl = loss_style(style_output)
371
+ loss_s_sup = torch.Tensor([0]).to(model.device)
372
+ elif args.style_loss_type == 'OnlySup':
373
+ loss_s_ssl = torch.Tensor([0]).to(model.device)
374
+ loss_s_sup = loss_style(style_output[:, 0:1, :], mask=label_mask)
375
+ else:
376
+ loss_s_sup = loss_style(style_output[:, 0:1, :], mask=label_mask)
377
+ loss_s_ssl = loss_style(style_output)
378
+
379
+ loss_s = args.lam_sup*loss_s_sup + loss_s_ssl
380
+
381
+ loss = args.lambda_c * loss_c + args.lambda_s * loss_s
382
+
383
+ if not math.isfinite(loss.item()):
384
+ print("Loss is {}, stopping training".format(loss.item()))
385
+ sys.exit(1)
386
+
387
+ opt_bb.zero_grad()
388
+ opt_proj.zero_grad()
389
+ param_norms = None
390
+ if fp16_scaler is None:
391
+ loss.backward()
392
+ if args.clip_grad:
393
+ param_norms = utils.clip_gradients(model, args.clip_grad)
394
+ utils.cancel_gradients_last_layer(iter, model, args.freeze_last_layer)
395
+ opt_bb.step()
396
+ opt_proj.step()
397
+ else:
398
+ fp16_scaler.scale(loss).backward()
399
+ if args.clip_grad:
400
+ fp16_scaler.unscale_(opt_bb) # unscale the gradients of optimizer's assigned params in-place
401
+ fp16_scaler.unscale_(opt_proj)
402
+ param_norms = utils.clip_gradients(model, args.clip_grad)
403
+ utils.cancel_gradients_last_layer(iter, model, args.freeze_last_layer)
404
+ fp16_scaler.step(opt_bb)
405
+ fp16_scaler.step(opt_proj)
406
+ fp16_scaler.update()
407
+
408
+ # logging
409
+ torch.cuda.synchronize()
410
+ metric_logger.update(loss=loss.item())
411
+ metric_logger.update(content_loss=loss_c.item())
412
+ metric_logger.update(style_loss=loss_s.item())
413
+ metric_logger.update(style_loss_sup=loss_s_sup.item())
414
+ metric_logger.update(style_loss_ssl=loss_s_ssl.item())
415
+ metric_logger.update(lr_bb=opt_bb.param_groups[0]["lr"])
416
+ # metric_logger.update(wd_bb=opt_bb.param_groups[0]["weight_decay"])
417
+ metric_logger.update(lr_proj=opt_proj.param_groups[0]["lr"])
418
+ # metric_logger.update(wd_proj=opt_proj.param_groups[0]["weight_decay"])
419
+
420
+ # ============ writing logs ... ============
421
+ save_dict = {
422
+ 'model_state_dict': model.state_dict(),
423
+ 'opt_bb': opt_bb.state_dict(),
424
+ 'opt_proj': opt_proj.state_dict(),
425
+ 'iter': iter+1,
426
+ 'args': args,
427
+ }
428
+ if fp16_scaler is not None:
429
+ save_dict['fp16_scaler'] = fp16_scaler.state_dict()
430
+
431
+ if (iter+1) % args.saveckp_freq == 0:
432
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
433
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, f'checkpoint{iter+1:08}.pth'))
434
+
435
+ train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
436
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
437
+ 'iter': iter+1}
438
+
439
+ if utils.is_main_process() and (iter+1) % args.print_freq == 0:
440
+ with (Path(args.output_dir) / "log.txt").open("a") as f:
441
+ f.write(json.dumps(log_stats) + "\n")
442
+
443
+ # Eval
444
+ if (iter+1) % args.eval_freq==0:
445
+ # gather the stats from all processes
446
+ metric_logger.synchronize_between_processes()
447
+ print("Averaged stats:", metric_logger)
448
+
449
+ val_stats = evaluate(model, vq_loader, vidx_loader, fp16_scaler is not None, args.eval_k, args.eval_embed)
450
+
451
+ if args.iters > 0 and save_dict is not None:
452
+ utils.save_on_master(save_dict, os.path.join(args.output_dir, 'checkpoint.pth'))
453
+
454
+ total_time = time.time() - start_time
455
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
456
+ print('Training time {}'.format(total_time_str))
457
+
458
+
459
+ def evaluate(model, vq_loader, vidx_loader, use_fp16=False, eval_k=[1, 5, 100], eval_embed='head'):
460
+ metric_logger = utils.MetricLogger(delimiter=" ")
461
+ # Valid loader is the query set
462
+ # Train loader is the search set
463
+ use_cuda = True
464
+ db_features = utils.extract_features(model, vidx_loader,use_cuda, use_fp16, eval_embed)
465
+ q_features = utils.extract_features(model, vq_loader, use_cuda, use_fp16, eval_embed)
466
+
467
+ # Aggregate style features across GPUs
468
+ if utils.get_rank() != 0:
469
+ return
470
+
471
+ # Find the nearest neighbor indices for each query
472
+ similarities = q_features @ db_features.T
473
+ similarities = torch.argsort(similarities, dim=1, descending=True).cpu()
474
+
475
+ # Map neighbor indices to labels (assuming one hot labels)
476
+ q_labels = vq_loader.dataset.labels
477
+ db_labels = vidx_loader.dataset.labels
478
+ gts = q_labels @ db_labels.T
479
+ #TODO: vectorize this
480
+ preds = np.array([gts[i][similarities[i]] for i in range(len(gts))])
481
+
482
+ # Compute metrics
483
+ for topk in eval_k:
484
+ mode_recall = utils.Metrics.get_recall_bin(copy.deepcopy(preds), topk)
485
+ mode_mrr = utils.Metrics.get_mrr_bin(copy.deepcopy(preds), topk)
486
+ mode_map = utils.Metrics.get_map_bin(copy.deepcopy(preds), topk)
487
+ # print(f'Recall@{topk}: {mode_recall:.2f}, mAP@{topk}: {mode_map:.2f}')
488
+ metric_logger.update(**{f'recall@{topk}': mode_recall, f'mAP@{topk}': mode_map, f'MRR@{topk}': mode_mrr})
489
+
490
+ # gather the stats from all processes
491
+ print("Averaged stats:", metric_logger)
492
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
493
+
494
+
495
+ if __name__ == "__main__":
496
+ main()
CSD/CSD/utils.py ADDED
@@ -0,0 +1,853 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Misc functions.
16
+
17
+ Mostly copy-paste from torchvision references or other public repos like DETR:
18
+ https://github.com/facebookresearch/detr/blob/master/util/misc.py
19
+ """
20
+ import os
21
+ import sys
22
+ import time
23
+ import math
24
+ import random
25
+ import datetime
26
+ import subprocess
27
+ from collections import defaultdict, deque, OrderedDict
28
+
29
+ import numpy as np
30
+ import torch
31
+ from torch import nn
32
+ import torch.distributed as dist
33
+ import warnings
34
+ import argparse
35
+ from PIL import ImageFilter, ImageOps
36
+
37
+
38
+ class GaussianBlur(object):
39
+ """
40
+ Apply Gaussian Blur to the PIL image.
41
+ """
42
+
43
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
44
+ self.prob = p
45
+ self.radius_min = radius_min
46
+ self.radius_max = radius_max
47
+
48
+ def __call__(self, img):
49
+ do_it = random.random() <= self.prob
50
+ if not do_it:
51
+ return img
52
+
53
+ return img.filter(
54
+ ImageFilter.GaussianBlur(
55
+ radius=random.uniform(self.radius_min, self.radius_max)
56
+ )
57
+ )
58
+
59
+
60
+ class Solarization(object):
61
+ """
62
+ Apply Solarization to the PIL image.
63
+ """
64
+
65
+ def __init__(self, p):
66
+ self.p = p
67
+
68
+ def __call__(self, img):
69
+ if random.random() < self.p:
70
+ return ImageOps.solarize(img)
71
+ else:
72
+ return img
73
+
74
+
75
+ def clip_gradients(model, clip):
76
+ norms = []
77
+ for name, p in model.named_parameters():
78
+ if p.grad is not None:
79
+ param_norm = p.grad.data.norm(2)
80
+ norms.append(param_norm.item())
81
+ clip_coef = clip / (param_norm + 1e-6)
82
+ if clip_coef < 1:
83
+ p.grad.data.mul_(clip_coef)
84
+ return norms
85
+
86
+
87
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
88
+ """
89
+ Re-start from checkpoint
90
+ """
91
+ if not os.path.isfile(ckp_path):
92
+ return
93
+ print("Found checkpoint at {}".format(ckp_path))
94
+
95
+ # open checkpoint file
96
+ checkpoint = torch.load(ckp_path, map_location="cpu")
97
+
98
+ # key is what to look for in the checkpoint file
99
+ # value is the object to load
100
+ # example: {'state_dict': model}
101
+ for key, value in kwargs.items():
102
+ if key in checkpoint and value is not None:
103
+ try:
104
+ msg = value.load_state_dict(checkpoint[key], strict=False)
105
+ print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
106
+ except TypeError:
107
+ try:
108
+ msg = value.load_state_dict(checkpoint[key])
109
+ print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
110
+ except ValueError:
111
+ print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
112
+ else:
113
+ print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))
114
+
115
+ # re load variable important for the run
116
+ if run_variables is not None:
117
+ for var_name in run_variables:
118
+ if var_name in checkpoint:
119
+ run_variables[var_name] = checkpoint[var_name]
120
+
121
+
122
+ def convert_state_dict(state_dict):
123
+ new_state_dict = OrderedDict()
124
+ for k, v in state_dict.items():
125
+ if k.startswith("module."):
126
+ k = k.replace("module.", "")
127
+ new_state_dict[k] = v
128
+ return new_state_dict
129
+
130
+
131
+ def cosine_scheduler(base_value, final_value, iters, warmup_iters, start_warmup_value=0):
132
+ warmup_schedule = np.array([])
133
+
134
+ if warmup_iters > 0:
135
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
136
+
137
+ post_warmup_iters = np.arange(iters - warmup_iters)
138
+ schedule = final_value + 0.5 * (base_value - final_value) * (
139
+ 1 + np.cos(np.pi * post_warmup_iters / len(post_warmup_iters)))
140
+
141
+ schedule = np.concatenate((warmup_schedule, schedule))
142
+ assert len(schedule) == iters
143
+ return schedule
144
+
145
+
146
+ def constant_with_warmup_scheduler(base_value, iters, warmup_iters=0, start_warmup_value=0):
147
+ warmup_schedule = np.array([])
148
+
149
+ if warmup_iters > 0:
150
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
151
+
152
+ num_iters = iters - warmup_iters
153
+ schedule = np.array([base_value] * num_iters)
154
+
155
+ schedule = np.concatenate((warmup_schedule, schedule))
156
+ assert len(schedule) == iters
157
+ return schedule
158
+
159
+
160
+ def bool_flag(s):
161
+ """
162
+ Parse boolean arguments from the command line.
163
+ """
164
+ FALSY_STRINGS = {"off", "false", "0"}
165
+ TRUTHY_STRINGS = {"on", "true", "1"}
166
+ if s.lower() in FALSY_STRINGS:
167
+ return False
168
+ elif s.lower() in TRUTHY_STRINGS:
169
+ return True
170
+ else:
171
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
172
+
173
+
174
+ def fix_random_seeds(seed=31):
175
+ """
176
+ Fix random seeds.
177
+ """
178
+ torch.manual_seed(seed)
179
+ torch.cuda.manual_seed_all(seed)
180
+ np.random.seed(seed)
181
+
182
+
183
+ class SmoothedValue(object):
184
+ """Track a series of values and provide access to smoothed values over a
185
+ window or the global series average.
186
+ """
187
+
188
+ def __init__(self, window_size=20, fmt=None):
189
+ if fmt is None:
190
+ fmt = "{median:.6f} ({global_avg:.6f})"
191
+ self.deque = deque(maxlen=window_size)
192
+ self.total = 0.0
193
+ self.count = 0
194
+ self.fmt = fmt
195
+
196
+ def update(self, value, n=1):
197
+ self.deque.append(value)
198
+ self.count += n
199
+ self.total += value * n
200
+
201
+ def synchronize_between_processes(self):
202
+ """
203
+ Warning: does not synchronize the deque!
204
+ """
205
+ if not is_dist_avail_and_initialized():
206
+ return
207
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
208
+ dist.barrier()
209
+ dist.all_reduce(t)
210
+ t = t.tolist()
211
+ self.count = int(t[0])
212
+ self.total = t[1]
213
+
214
+ @property
215
+ def median(self):
216
+ d = torch.tensor(list(self.deque))
217
+ return d.median().item()
218
+
219
+ @property
220
+ def avg(self):
221
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
222
+ return d.mean().item()
223
+
224
+ @property
225
+ def global_avg(self):
226
+ return self.total / self.count
227
+
228
+ @property
229
+ def max(self):
230
+ return max(self.deque)
231
+
232
+ @property
233
+ def value(self):
234
+ return self.deque[-1]
235
+
236
+ def __str__(self):
237
+ return self.fmt.format(
238
+ median=self.median,
239
+ avg=self.avg,
240
+ global_avg=self.global_avg,
241
+ max=self.max,
242
+ value=self.value)
243
+
244
+
245
+ def reduce_dict(input_dict, average=True):
246
+ """
247
+ Args:
248
+ input_dict (dict): all the values will be reduced
249
+ average (bool): whether to do average or sum
250
+ Reduce the values in the dictionary from all processes so that all processes
251
+ have the averaged results. Returns a dict with the same fields as
252
+ input_dict, after reduction.
253
+ """
254
+ world_size = get_world_size()
255
+ if world_size < 2:
256
+ return input_dict
257
+ with torch.no_grad():
258
+ names = []
259
+ values = []
260
+ # sort the keys so that they are consistent across processes
261
+ for k in sorted(input_dict.keys()):
262
+ names.append(k)
263
+ values.append(input_dict[k])
264
+ values = torch.stack(values, dim=0)
265
+ dist.all_reduce(values)
266
+ if average:
267
+ values /= world_size
268
+ reduced_dict = {k: v for k, v in zip(names, values)}
269
+ return reduced_dict
270
+
271
+
272
+ class MetricLogger(object):
273
+ def __init__(self, delimiter="\t", max_len=100):
274
+ self.meters = defaultdict(SmoothedValue)
275
+ self.delimiter = delimiter
276
+ self.max_len = max_len
277
+
278
+ def update(self, **kwargs):
279
+ for k, v in kwargs.items():
280
+ if isinstance(v, torch.Tensor):
281
+ v = v.item()
282
+ assert isinstance(v, (float, int))
283
+ self.meters[k].update(v)
284
+
285
+ def __getattr__(self, attr):
286
+ if attr in self.meters:
287
+ return self.meters[attr]
288
+ if attr in self.__dict__:
289
+ return self.__dict__[attr]
290
+ raise AttributeError("'{}' object has no attribute '{}'".format(
291
+ type(self).__name__, attr))
292
+
293
+ def __str__(self):
294
+ loss_str = []
295
+ for name, meter in self.meters.items():
296
+ loss_str.append(
297
+ "{}: {}".format(name, str(meter))
298
+ )
299
+ return self.delimiter.join(loss_str)
300
+
301
+ def synchronize_between_processes(self):
302
+ for meter in self.meters.values():
303
+ meter.synchronize_between_processes()
304
+
305
+ def add_meter(self, name, meter):
306
+ self.meters[name] = meter
307
+
308
+ def log_every(self, iterable, print_freq, header=None):
309
+ i = 0
310
+ if not header:
311
+ header = ''
312
+ start_time = time.time()
313
+ end = time.time()
314
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
315
+ data_time = SmoothedValue(fmt='{avg:.6f}')
316
+ space_fmt = ':' + str(len(str(self.max_len))) + 'd'
317
+ if torch.cuda.is_available():
318
+ log_msg = self.delimiter.join([
319
+ header,
320
+ '[{0' + space_fmt + '}/{1}]',
321
+ 'eta: {eta}',
322
+ '{meters}',
323
+ 'time: {time}',
324
+ 'data: {data}',
325
+ 'max mem: {memory:.0f}'
326
+ ])
327
+ else:
328
+ log_msg = self.delimiter.join([
329
+ header,
330
+ '[{0' + space_fmt + '}/{1}]',
331
+ 'eta: {eta}',
332
+ '{meters}',
333
+ 'time: {time}',
334
+ 'data: {data}'
335
+ ])
336
+ MB = 1024.0 * 1024.0
337
+ for obj in iterable:
338
+ data_time.update(time.time() - end)
339
+ yield obj
340
+ iter_time.update(time.time() - end)
341
+ if i % print_freq == 0 or i == self.max_len - 1:
342
+ eta_seconds = iter_time.global_avg * (self.max_len - i)
343
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
344
+ if torch.cuda.is_available():
345
+ print(log_msg.format(
346
+ i, self.max_len,
347
+ eta=eta_string,
348
+ meters=str(self),
349
+ time=str(iter_time), data=str(data_time),
350
+ memory=torch.cuda.max_memory_allocated() / MB))
351
+ else:
352
+ print(log_msg.format(
353
+ i, self.max_len,
354
+ eta=eta_string,
355
+ meters=str(self),
356
+ time=str(iter_time), data=str(data_time)))
357
+ i += 1
358
+ end = time.time()
359
+ total_time = time.time() - start_time
360
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
361
+ print('{} Total time: {} ({:.6f} s / it)'.format(
362
+ header, total_time_str, total_time / self.max_len))
363
+
364
+
365
+ def get_sha():
366
+ cwd = os.path.dirname(os.path.abspath(__file__))
367
+
368
+ def _run(command):
369
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
370
+
371
+ sha = 'N/A'
372
+ diff = "clean"
373
+ branch = 'N/A'
374
+ try:
375
+ sha = _run(['git', 'rev-parse', 'HEAD'])
376
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
377
+ diff = _run(['git', 'diff-index', 'HEAD'])
378
+ diff = "has uncommited changes" if diff else "clean"
379
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
380
+ except Exception:
381
+ pass
382
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
383
+ return message
384
+
385
+
386
+ def is_dist_avail_and_initialized():
387
+ if not dist.is_available():
388
+ return False
389
+ if not dist.is_initialized():
390
+ return False
391
+ return True
392
+
393
+
394
+ def get_world_size():
395
+ if not is_dist_avail_and_initialized():
396
+ return 1
397
+ return dist.get_world_size()
398
+
399
+
400
+ def get_rank():
401
+ if not is_dist_avail_and_initialized():
402
+ return 0
403
+ return dist.get_rank()
404
+
405
+
406
+ def is_main_process():
407
+ return get_rank() == 0
408
+
409
+
410
+ def save_on_master(*args, **kwargs):
411
+ if is_main_process():
412
+ torch.save(*args, **kwargs)
413
+
414
+
415
+ def setup_for_distributed(is_master):
416
+ """
417
+ This function disables printing when not in master process
418
+ """
419
+ import builtins as __builtin__
420
+ builtin_print = __builtin__.print
421
+
422
+ def print(*args, **kwargs):
423
+ force = kwargs.pop('force', False)
424
+ if is_master or force:
425
+ builtin_print(*args, **kwargs)
426
+
427
+ __builtin__.print = print
428
+
429
+
430
+ def init_distributed_mode(args):
431
+ # launched with torch.distributed.launch
432
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
433
+ args.rank = int(os.environ["RANK"])
434
+ args.world_size = int(os.environ['WORLD_SIZE'])
435
+ args.gpu = int(os.environ['LOCAL_RANK'])
436
+ # launched with submitit on a slurm cluster
437
+ elif 'SLURM_PROCID' in os.environ:
438
+ args.rank = int(os.environ['SLURM_PROCID'])
439
+ args.gpu = args.rank % torch.cuda.device_count()
440
+ # launched naively with `python main_dino.py`
441
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
442
+ elif torch.cuda.is_available():
443
+ print('Will run the code on one GPU.')
444
+ args.rank, args.gpu, args.world_size = 0, 0, 1
445
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
446
+ os.environ['MASTER_PORT'] = '29500'
447
+ else:
448
+ print('Does not support training without GPU.')
449
+ sys.exit(1)
450
+
451
+ if torch.cuda.device_count() > 0:
452
+ args.distributed = True
453
+ else:
454
+ args.distributed = False
455
+
456
+ dist.init_process_group(
457
+ backend="nccl",
458
+ init_method=args.dist_url,
459
+ world_size=args.world_size,
460
+ rank=args.rank,
461
+ )
462
+
463
+ torch.cuda.set_device(args.gpu)
464
+ print('| distributed init (rank {}): {}'.format(
465
+ args.rank, args.dist_url), flush=True)
466
+ dist.barrier()
467
+ setup_for_distributed(args.rank == 0)
468
+
469
+
470
+ def accuracy(output, target, topk=(1,)):
471
+ """Computes the accuracy over the k top predictions for the specified values of k"""
472
+ maxk = max(topk)
473
+ batch_size = target.size(0)
474
+ _, pred = output.topk(maxk, 1, True, True)
475
+ pred = pred.t()
476
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
477
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
478
+
479
+
480
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
481
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
482
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
483
+ def norm_cdf(x):
484
+ # Computes standard normal cumulative distribution function
485
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
486
+
487
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
488
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
489
+ "The distribution of values may be incorrect.",
490
+ stacklevel=2)
491
+
492
+ with torch.no_grad():
493
+ # Values are generated by using a truncated uniform distribution and
494
+ # then using the inverse CDF for the normal distribution.
495
+ # Get upper and lower cdf values
496
+ l = norm_cdf((a - mean) / std)
497
+ u = norm_cdf((b - mean) / std)
498
+
499
+ # Uniformly fill tensor with values from [l, u], then translate to
500
+ # [2l-1, 2u-1].
501
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
502
+
503
+ # Use inverse cdf transform for normal distribution to get truncated
504
+ # standard normal
505
+ tensor.erfinv_()
506
+
507
+ # Transform to proper mean, std
508
+ tensor.mul_(std * math.sqrt(2.))
509
+ tensor.add_(mean)
510
+
511
+ # Clamp to ensure it's in the proper range
512
+ tensor.clamp_(min=a, max=b)
513
+ return tensor
514
+
515
+
516
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
517
+ # type: (Tensor, float, float, float, float) -> Tensor
518
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
519
+
520
+
521
+ class LARS(torch.optim.Optimizer):
522
+ """
523
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
524
+ """
525
+
526
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
527
+ weight_decay_filter=None, lars_adaptation_filter=None):
528
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
529
+ eta=eta, weight_decay_filter=weight_decay_filter,
530
+ lars_adaptation_filter=lars_adaptation_filter)
531
+ super().__init__(params, defaults)
532
+
533
+ @torch.no_grad()
534
+ def step(self):
535
+ for g in self.param_groups:
536
+ for p in g['params']:
537
+ dp = p.grad
538
+
539
+ if dp is None:
540
+ continue
541
+
542
+ if p.ndim != 1:
543
+ dp = dp.add(p, alpha=g['weight_decay'])
544
+
545
+ if p.ndim != 1:
546
+ param_norm = torch.norm(p)
547
+ update_norm = torch.norm(dp)
548
+ one = torch.ones_like(param_norm)
549
+ q = torch.where(param_norm > 0.,
550
+ torch.where(update_norm > 0,
551
+ (g['eta'] * param_norm / update_norm), one), one)
552
+ dp = dp.mul(q)
553
+
554
+ param_state = self.state[p]
555
+ if 'mu' not in param_state:
556
+ param_state['mu'] = torch.zeros_like(p)
557
+ mu = param_state['mu']
558
+ mu.mul_(g['momentum']).add_(dp)
559
+
560
+ p.add_(mu, alpha=-g['lr'])
561
+
562
+
563
+ class MultiCropWrapper(nn.Module):
564
+ """
565
+ Perform forward pass separately on each resolution input.
566
+ The inputs corresponding to a single resolution are clubbed and single
567
+ forward is run on the same resolution inputs. Hence we do several
568
+ forward passes = number of different resolutions used. We then
569
+ concatenate all the output features and run the head forward on these
570
+ concatenated features.
571
+ """
572
+
573
+ def __init__(self, backbone, head):
574
+ super(MultiCropWrapper, self).__init__()
575
+ # disable layers dedicated to ImageNet labels classification
576
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
577
+ self.backbone = backbone
578
+ self.head = head
579
+
580
+ def forward(self, x):
581
+ # convert to list
582
+ if not isinstance(x, list):
583
+ x = [x]
584
+ idx_crops = torch.cumsum(torch.unique_consecutive(
585
+ torch.tensor([inp.shape[-1] for inp in x]),
586
+ return_counts=True,
587
+ )[1], 0)
588
+ start_idx, output = 0, torch.empty(0).to(x[0].device)
589
+ for end_idx in idx_crops:
590
+ _out = self.backbone(torch.cat(x[start_idx: end_idx]))
591
+ # The output is a tuple with XCiT model. See:
592
+ # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
593
+ if isinstance(_out, tuple):
594
+ _out = _out[0]
595
+ # accumulate outputs
596
+ output = torch.cat((output, _out))
597
+ start_idx = end_idx
598
+ # Run the head forward on the concatenated features.
599
+ return self.head(output)
600
+
601
+
602
+ def get_params_groups(model):
603
+ regularized = []
604
+ not_regularized = []
605
+ for name, param in model.named_parameters():
606
+ if not param.requires_grad:
607
+ continue
608
+ # we do not regularize biases nor Norm parameters
609
+ if name.endswith(".bias") or len(param.shape) == 1:
610
+ not_regularized.append(param)
611
+ else:
612
+ regularized.append(param)
613
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
614
+
615
+
616
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
617
+ if epoch >= freeze_last_layer:
618
+ return
619
+ for n, p in model.named_parameters():
620
+ if "last_layer" in n:
621
+ p.grad = None
622
+
623
+
624
+ def has_batchnorms(model):
625
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
626
+ for name, module in model.named_modules():
627
+ if isinstance(module, bn_types):
628
+ return True
629
+ return False
630
+
631
+
632
+ #####
633
+ def convert_weights_float(model: nn.Module):
634
+ """Convert applicable model parameters to fp32"""
635
+
636
+ def _convert_weights_to_fp32(l):
637
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
638
+ l.weight.data = l.weight.data.float()
639
+ if l.bias is not None:
640
+ l.bias.data = l.bias.data.float()
641
+
642
+ if isinstance(l, nn.MultiheadAttention):
643
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
644
+ tensor = getattr(l, attr)
645
+ if tensor is not None:
646
+ tensor.data = tensor.data.float()
647
+
648
+ for name in ["text_projection", "proj"]:
649
+ if hasattr(l, name):
650
+ attr = getattr(l, name)
651
+ if attr is not None:
652
+ attr.data = attr.data.float()
653
+
654
+ model.apply(_convert_weights_to_fp32)
655
+
656
+
657
+ def split_reshape(x, bs, combination=None):
658
+ n = len(x) // bs
659
+ assert n in [2, 3], "The num augs should be 2 or 3 in number"
660
+ f = torch.split(x, [bs] * n, dim=0)
661
+ if combination is None:
662
+ x_reshape = torch.cat([f[i].unsqueeze(1) for i in range(n)], dim=1)
663
+ else:
664
+ x_reshape = torch.cat([f[i].unsqueeze(1) for i in combination], dim=1)
665
+
666
+ # if repeatcase:
667
+ # x_reshape = torch.cat([f1.unsqueeze(1), f1.unsqueeze(1)], dim=1)
668
+ return x_reshape
669
+
670
+
671
+ class AverageMeter(object):
672
+ """Computes and stores the average and current value"""
673
+
674
+ def __init__(self, name, fmt=":f"):
675
+ self.name = name
676
+ self.fmt = fmt
677
+ self.reset()
678
+
679
+ def reset(self):
680
+ self.val = 0
681
+ self.avg = 0
682
+ self.sum = 0
683
+ self.count = 0
684
+
685
+ def update(self, val, n=1):
686
+ self.val = val
687
+ self.sum += val * n
688
+ self.count += n
689
+ self.avg = self.sum / self.count
690
+
691
+ def __str__(self):
692
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
693
+ return fmtstr.format(**self.__dict__)
694
+
695
+
696
+ class ProgressMeter(object):
697
+ def __init__(self, num_batches, meters, prefix=""):
698
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
699
+ self.meters = meters
700
+ self.prefix = prefix
701
+
702
+ def display(self, batch):
703
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
704
+ entries += [str(meter) for meter in self.meters]
705
+ print("\t".join(entries))
706
+
707
+ def _get_batch_fmtstr(self, num_batches):
708
+ num_digits = len(str(num_batches // 1))
709
+ fmt = "{:" + str(num_digits) + "d}"
710
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
711
+
712
+
713
+ @torch.no_grad()
714
+ def extract_features(model, data_loader, use_cuda=True, use_fp16=False, eval_embed='head'):
715
+ metric_logger = MetricLogger(delimiter=" ")
716
+ features = None
717
+ # count = 0
718
+ for samples, *_, index in metric_logger.log_every(data_loader, 100):
719
+ # print(f'At the index {index}')
720
+ samples = samples.cuda(non_blocking=True)
721
+ index = index.cuda(non_blocking=True)
722
+ if use_fp16:
723
+ with torch.cuda.amp.autocast():
724
+ bb_feats, cont_feats, style_feats = model(samples)
725
+
726
+ if eval_embed == 'backbone':
727
+ feats = bb_feats.clone()
728
+ else:
729
+ feats = style_feats.clone()
730
+
731
+ else:
732
+ bb_feats, cont_feats, style_feats = model(samples)
733
+ if eval_embed == 'backbone':
734
+ feats = bb_feats.clone()
735
+ else:
736
+ feats = style_feats.clone()
737
+ # init storage feature matrix
738
+ if dist.get_rank() == 0 and features is None:
739
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1], dtype=feats.dtype)
740
+ if use_cuda:
741
+ features = features.cuda(non_blocking=True)
742
+ print(f"Storing features into tensor of shape {features.shape}")
743
+
744
+ # get indexes from all processes
745
+ y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
746
+ y_l = list(y_all.unbind(0))
747
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
748
+ y_all_reduce.wait()
749
+ index_all = torch.cat(y_l)
750
+
751
+ # share features between processes
752
+ feats_all = torch.empty(
753
+ dist.get_world_size(),
754
+ feats.size(0),
755
+ feats.size(1),
756
+ dtype=feats.dtype,
757
+ device=feats.device,
758
+ )
759
+ output_l = list(feats_all.unbind(0))
760
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
761
+ output_all_reduce.wait()
762
+
763
+ # update storage feature matrix
764
+ if dist.get_rank() == 0:
765
+ if use_cuda:
766
+ features.index_copy_(0, index_all, torch.cat(output_l))
767
+ else:
768
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
769
+ return features
770
+
771
+
772
+ # Copy from https://github.com/learn2phoenix/dynamicDistances/blob/main/metrics/metrics.py
773
+ class Metrics(object):
774
+ def __init__(self):
775
+ self.data = None
776
+
777
+ @staticmethod
778
+ def get_recall(preds, gts, topk=5):
779
+ preds = preds[:, :topk]
780
+ preds -= gts[:, None]
781
+ found = np.where(np.amin(np.absolute(preds), axis=1) == 0)[0]
782
+ return found.shape[0] / gts.shape[0]
783
+
784
+ @staticmethod
785
+ def get_mrr(preds, gts, topk=5):
786
+ preds = preds[:, :topk]
787
+ preds -= gts[:, None]
788
+ rows, cols = np.where(preds == 0)
789
+ _, unique_rows = np.unique(rows, return_index=True)
790
+ valid_cols = cols[unique_rows]
791
+ valid_cols += 1
792
+ return np.mean(1 / valid_cols)
793
+
794
+ @staticmethod
795
+ def get_map(preds, gts, topk=5):
796
+ preds = preds[:, :topk]
797
+ preds -= gts[:, None]
798
+ rows, cols = np.where(preds == 0)
799
+ _, unique_rows = np.unique(rows, return_index=True)
800
+ row_cols = np.split(cols, unique_rows)[1:]
801
+ row_cols = [np.hstack([x[0], np.diff(x), topk - x[-1]]) for x in row_cols]
802
+ row_cols = [np.pad(x, (0, topk + 1 - x.shape[0]), 'constant', constant_values=(0, 0)) for x in row_cols]
803
+ precision = np.asarray([np.repeat(np.arange(topk + 1), x) / np.arange(1, topk + 1) for x in row_cols])
804
+ return np.sum(np.mean(precision, axis=1)) / preds.shape[0]
805
+
806
+ @staticmethod
807
+ def get_recall_bin(preds, topk=5):
808
+ # preds is a binary matrix of size Q x K
809
+ preds = preds[:, :topk]
810
+ found = np.where(np.amax(preds, axis=1) == True)[0]
811
+ return found.shape[0] / preds.shape[0]
812
+
813
+ @staticmethod
814
+ def get_mrr_bin(preds, topk=5):
815
+ # preds is a binary matrix of size Q x K
816
+ preds = preds[:, :topk]
817
+ rows, cols = np.where(preds)
818
+ _, unique_rows = np.unique(rows, return_index=True)
819
+ valid_cols = cols[unique_rows]
820
+ valid_cols += 1
821
+ return np.mean(1 / valid_cols)
822
+
823
+ @staticmethod
824
+ def get_map_bin(preds, topk=5):
825
+ # preds is a binary matrix of size Q x K
826
+ preds = preds[:, :topk]
827
+ rows, cols = np.where(preds)
828
+ _, unique_rows = np.unique(rows, return_index=True)
829
+ row_cols = np.split(cols, unique_rows)[1:]
830
+ row_cols = [np.hstack([x[0], np.diff(x), topk - x[-1]]) for x in row_cols]
831
+ row_cols = [np.pad(x, (0, topk + 1 - x.shape[0]), 'constant', constant_values=(0, 0)) for x in row_cols]
832
+ precision = np.asarray([np.repeat(np.arange(topk + 1), x) / np.arange(1, topk + 1) for x in row_cols])
833
+ return np.sum(np.mean(precision, axis=1)) / preds.shape[0]
834
+
835
+
836
+ class GatherLayer(torch.autograd.Function):
837
+ """Gather tensors from all process, supporting backward propagation.
838
+ """
839
+
840
+ @staticmethod
841
+ def forward(ctx, input):
842
+ ctx.save_for_backward(input)
843
+ output = [torch.zeros_like(input) \
844
+ for _ in range(dist.get_world_size())]
845
+ dist.all_gather(output, input)
846
+ return tuple(output)
847
+
848
+ @staticmethod
849
+ def backward(ctx, *grads):
850
+ input, = ctx.saved_tensors
851
+ grad_out = torch.zeros_like(input)
852
+ grad_out[:] = grads[dist.get_rank()]
853
+ return grad_out
CSD/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 learn2phoenix
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
CSD/README.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Measuring Style Similarity in Diffusion Models
2
+ Check out the paper here - [arxiv](https://arxiv.org/abs/2404.01292).
3
+
4
+ ![alt text](github_teaser.jpg "Generations from Stable Diffusion and corresponding matches from LAION-Styles split")
5
+
6
+ ## Create and activate the environment
7
+
8
+ ```
9
+ conda env create -f environment.yml
10
+ conda activate style
11
+ ```
12
+
13
+ ## Download the pretrained weights for the CSD model
14
+
15
+ Please download the CSD model (ViT-L) weights [here](https://drive.google.com/file/d/1FX0xs8p-C7Ob-h5Y4cUhTeOepHzXv_46/view?usp=sharing).
16
+
17
+
18
+ ## Download the pretrained weights for the baseline models
19
+
20
+ You need these only if you want to test the baseline numbers. For `CLIP` and `DINO`, pretrained weights will be downloaded automatically. For `SSCD` and `MoCo`, please download the weights
21
+ from the links below and put them in `./pretrainedmodels` folder.
22
+
23
+ * SSCD: [resnet50](https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt)
24
+ * MoCO: [ViT-B](https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar)
25
+
26
+
27
+
28
+ ## Download the WikiArt dataset
29
+ WikiArt can be downloaded from [here](https://drive.google.com/file/d/1vTChp3nU5GQeLkPwotrybpUGUXj12BTK/view?usp=drivesdk0)
30
+ or [here1](http://web.fsktm.um.edu.my/~cschan/source/ICIP2017/wikiart.zip)
31
+
32
+ After dataset is downloaded please put `./wikiart.csv` in the parent directory of the dataset. The final directory structure should look like this:
33
+ ```
34
+ path/to/WikiArt
35
+ ├── wikiart
36
+    ├── Abstract_Expressionism
37
+       ├── <filename>.jpg
38
+    ├── ...
39
+ └── wikiart.csv
40
+ ```
41
+
42
+ Also, make sure that you add a column `path` in the `wikiart.csv` file which contains the absolute path to the image.
43
+
44
+ ## Generate the embeddings
45
+
46
+ Once WikiArt dataset is set up, you can generate the CSD embeddings by running the following command. Please adjust
47
+ the `--data-dir` and `--embed_dir` accordingly. You should also adjust the batch size `--b` and number of workers `--j`
48
+ according to your machine. The command to generate baseline embeddings is same, you just need to change the `--pt_style`
49
+ with any of the following: `clip`, `dino`, `sscd`, `moco`.
50
+
51
+ ```angular2html
52
+ python main_sim.py --dataset wikiart -a vit_large --pt_style csd --feattype normal --world-size 1
53
+ --dist-url tcp://localhost:6001 -b 128 -j 8 --embed_dir ./embeddings --data-dir <path to WikiArt dataset>
54
+ --model_path <path to CSD weights>
55
+ ```
56
+
57
+ ## Evaluate
58
+ Once you've generated the embeddings, run the following command:
59
+
60
+ ```angular2html
61
+ python search.py --mode artist --dataset wikiart --chunked --query-chunk-dir <path to query embeddings above>
62
+ --database-chunk-dir <path to database embeddings above> --topk 1 10 100 1000 --method IP --data-dir <path to WikiArt dataset>
63
+ ```
64
+
65
+ ## Train CSD on LAION-Styles
66
+
67
+ You can also train style descriptors for your own datasets. A sample code for training on LAION-styles dataset is provided below.
68
+
69
+ We have started to release the **Contra-Styles** (referred to as LAION-Styles in the paper) dataset. The dataset is available [here](https://huggingface.co/datasets/tomg-group-umd/ContraStyles)
70
+ and will keep getting updated over the next few days as we are running profanity checks through NSFW and PhotoDNA. We will update here once the dataset has been completely uploaded.
71
+
72
+ ```
73
+ export PYTHONPATH="$PWD:$PYTHONPATH"
74
+
75
+ torchrun --standalone --nproc_per_node=4 CSD/train_csd.py --arch vit_base -j 8 -b 32 --maxsize 512 --resume_if_available --eval_k 1 10 100 --use_fp16 --use_distributed_loss --train_set laion_dedup --train_path <PATH to LAION-Styles> --eval_path <PATH to WikiArt/some val set> --output_dir <PATH to save checkpoint>
76
+ ```
77
+
78
+ ## Pending items
79
+
80
+ We will soon release the code to compute the artists' prototypical style representations and compute similarity score against any given generation. ETA end of June'24.
81
+
82
+ ## Cite us
83
+
84
+ ```
85
+ @article{somepalli2024measuring,
86
+ title={Measuring Style Similarity in Diffusion Models},
87
+ author={Somepalli, Gowthami and Gupta, Anubhav and Gupta, Kamal and Palta, Shramay and Goldblum, Micah and Geiping, Jonas and Shrivastava, Abhinav and Goldstein, Tom},
88
+ journal={arXiv preprint arXiv:2404.01292},
89
+ year={2024}
90
+ }
91
+ ```
CSD/__init__.py ADDED
File without changes
CSD/artists_400.txt ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ leonid afremov
2
+ georges seurat
3
+ amedeo modigliani
4
+ Alan bean
5
+ scott naismith
6
+ antoine blanchard
7
+ frederic remington
8
+ Artgerm
9
+ Greg Rutkowski
10
+ mucha
11
+ Alphonse Mucha
12
+ WLOP
13
+ Ilya Kuvshinov
14
+ stanley artgerm
15
+ Makoto Shinkai
16
+ rossdraws
17
+ James Jean
18
+ Magali Villeneuve
19
+ Donato Giancola
20
+ loish
21
+ Ruan Jia
22
+ Studio Ghibli
23
+ Beeple
24
+ Ross Tran
25
+ Beksinski
26
+ Peter Mohrbacher
27
+ Marc Simonetti
28
+ Tom Bagshaw
29
+ Craig Mullins
30
+ Pixar
31
+ Boris Vallejo
32
+ joseph christian leyendecker
33
+ Charlie Bowater
34
+ Dan Mumford
35
+ Moebius
36
+ amano
37
+ tomasz alen kopera
38
+ Norman Rockwell
39
+ Thomas Kinkade
40
+ Giger
41
+ Michael Garmash
42
+ Victo Ngai
43
+ Lois van Baarle
44
+ Klimt
45
+ RHADS
46
+ Jordan Grimmer
47
+ Bruce Pennington
48
+ James Gurney
49
+ disney
50
+ Michael Whelan
51
+ Frank Frazetta
52
+ Tim Hildebrandt
53
+ Raphael
54
+ Caravaggio
55
+ James Gilleard
56
+ Simon Stalenhag
57
+ Gil Elvgren
58
+ Zdzislaw Beksinski
59
+ Edward Hopper
60
+ gaston bussiere
61
+ Larry Elmore
62
+ Fenghua Zhong
63
+ Mike Mignola
64
+ Karol Bak
65
+ Francis Bacon
66
+ Brom
67
+ Syd Mead
68
+ Gustav Klimt
69
+ Jesper Ejsing
70
+ Noah Bradley
71
+ Steve McCurry
72
+ Gustave Dore
73
+ Atey Ghailan
74
+ gustav dore
75
+ Alex Ross
76
+ Mark Brooks
77
+ albert aublet
78
+ Raphael Lacoste
79
+ John Collier
80
+ Alex Grey
81
+ Rembrandt
82
+ Eddie Mendoza
83
+ Julie Bell
84
+ Android Jones
85
+ Miyazaki
86
+ Greg Hildebrandt
87
+ Gerald Brom
88
+ Darek Zabrocki
89
+ Titian
90
+ Jim Burns
91
+ Guy Denning
92
+ Mattias Adolfsson
93
+ Raymond Swanland
94
+ Ernst
95
+ Wes Anderson
96
+ Feng Zhu
97
+ Luis Royo
98
+ Justin Gerard
99
+ Hayao Miyazaki
100
+ Takato Yamamoto
101
+ Tyler Edlin
102
+ Bob Eggleton
103
+ Jakub Rozalski
104
+ Tuomas Korpi
105
+ Maxfield Parrish
106
+ Kilian Eng
107
+ Rebecca Guay
108
+ ferdinand knab
109
+ Jamie Wyeth
110
+ John Berkey
111
+ John Singer Sargent
112
+ Rene Magritte
113
+ Zaha Hadid
114
+ Josan Gonzalez
115
+ Shaddy Safadi
116
+ Carl Spitzweg
117
+ hajime sorayama
118
+ Conrad Roset
119
+ lovecraft
120
+ Simon Stålenhag
121
+ Masamune Shirow
122
+ Leonardo da Vinci
123
+ Gediminas Pranckevicius
124
+ charles vess
125
+ Jason Chan
126
+ Anton Fadeev
127
+ Albert Bierstadt
128
+ hr giger
129
+ Dali
130
+ Roger Dean
131
+ Zdzisław Beksiński
132
+ Bayard Wu
133
+ Ivan Shishkin
134
+ Bob Ross
135
+ Andreas Rocha
136
+ Warhol
137
+ Joe Fenton
138
+ Hiroshi Yoshida
139
+ Goro Fujita
140
+ Cedric Peyravernay
141
+ Jan van Eyck
142
+ Ismail Inceoglu
143
+ Ralph Horsley
144
+ Andy Warhol
145
+ Tomer Hanuka
146
+ Jean Giraud
147
+ Gustave Courbet
148
+ Roberto Ferri
149
+ Dustin Nguyen
150
+ Mark Arian
151
+ Louis Wain
152
+ Ernst Haeckel
153
+ Ivan Aivazovsky
154
+ Salvador Dali
155
+ Arthur Rackham
156
+ Louis Comfort Tiffany
157
+ Maciej Kuciara
158
+ John Harris
159
+ Andrew Wyeth
160
+ Mark Ryden
161
+ Junji Ito
162
+ sung choi
163
+ Alan Lee
164
+ sylvain sarrailh
165
+ Gaudi
166
+ Max Ernst
167
+ Filip Hodas
168
+ Daarken
169
+ Ralph McQuarrie
170
+ Sailor Moon
171
+ roger deakins
172
+ Rosa Bonheur
173
+ Brad Kunkle
174
+ Lee Madgwick
175
+ Caspar David Friedrich
176
+ Alberto Vargas
177
+ Chris Foss
178
+ Alena Aenami
179
+ Ian McQue
180
+ Wadim Kashin
181
+ Jean Delville
182
+ Fra Angelico
183
+ Peter Elson
184
+ Martin Johnson Heade
185
+ John Howe
186
+ Anna Dittmann
187
+ Zack Snyder
188
+ Jim Lee
189
+ Hieronymus Bosch
190
+ Josephine Wall
191
+ jessica rossier
192
+ Michelangelo
193
+ Michaelangelo
194
+ Ryohei Hase
195
+ Ilya Repin
196
+ Annie Leibovitz
197
+ Picasso
198
+ Stephan Martiniere
199
+ Frank Stella
200
+ Eugene von Guerard
201
+ Hokusai
202
+ Alexander McQueen
203
+ Tyler Jacobson
204
+ Monet
205
+ William Turner
206
+ Van Gogh
207
+ Anne Stokes
208
+ Jeff Koons
209
+ Frank Miller
210
+ Anton Pieck
211
+ Christopher Balaskas
212
+ Ernst Fuchs
213
+ Thomas Cole
214
+ Carne Griffiths
215
+ Mikhail Vrubel
216
+ John William Waterhouse
217
+ John William Godward
218
+ Arcimboldo
219
+ Vermeer
220
+ Daniel Merriam
221
+ James Paick
222
+ Takashi Murakami
223
+ Murakami
224
+ Jan Matejko
225
+ Banksy
226
+ Cyril Rolando
227
+ Amanda Sage
228
+ Miho Hirano
229
+ Eric Zener
230
+ Remedios Varo
231
+ Liam Wong
232
+ Art Green
233
+ Ed Roth
234
+ Drew Struzan
235
+ Jacek Yerka
236
+ Kelly McKernan
237
+ Raja Ravi Varma
238
+ ashley wood
239
+ Kandinsky
240
+ Sam Spratt
241
+ Rolf Armstrong
242
+ Bauhaus
243
+ Esao Andrews
244
+ ESAO
245
+ Richter
246
+ Gertrude Abercrombie
247
+ Yuumei
248
+ Jack Kirby
249
+ Victor Nizovtsev
250
+ Roy Lichtenstein
251
+ Lichtenstein
252
+ Harumi Hironaka
253
+ Paul Lehr
254
+ Les Edwards
255
+ Mike Winkelmann
256
+ Dan Luvisi
257
+ Art Frahm
258
+ ridley scott
259
+ Diego Rivera
260
+ irakli nadar
261
+ Dante Gabriel Rossetti
262
+ Francisco Goya
263
+ Evelyn De Morgan
264
+ Frederic Edwin Church
265
+ Frederick Edwin Church
266
+ Jon Foster
267
+ John Carpenter
268
+ Giuseppe Arcimboldo
269
+ Marcel Duchamp
270
+ MC Escher
271
+ Giorgio de Chirico
272
+ Frans Hals
273
+ Winslow Homer
274
+ adrian ghenie
275
+ Gerhard Richter
276
+ Cecil Beaton
277
+ Martine Johanna
278
+ Tom Whalen
279
+ Brian Froud
280
+ Sandra Chevrier
281
+ Vincent Van Gogh
282
+ Yasutomo Oka
283
+ Gregory Crewdson
284
+ George Stubbs
285
+ Eyvind Earle
286
+ Gustave Baumann
287
+ Yanjun Cheng
288
+ Tran Nguyen
289
+ Marina Abramović
290
+ Cy Twombly
291
+ Anselm Kiefer
292
+ John James Audubon
293
+ Chris Moore
294
+ Hasui Kawase
295
+ Scott Listfield
296
+ Hugh Ferriss
297
+ Claude Monet
298
+ Jeff Easley
299
+ Michael Komarck
300
+ Jeremy Geddes
301
+ Yves Tanguy
302
+ Svetlin Velinov
303
+ Lucian Freud
304
+ Viktor Vasnetsov
305
+ Gustave Doré
306
+ Hikari Shimoda
307
+ Edmund Dulac
308
+ William Blake
309
+ Thomas Eakins
310
+ Frederic Church
311
+ Gian Lorenzo Bernini
312
+ Bill Sienkiewicz
313
+ David Hockney
314
+ Lucas Graciano
315
+ national geographic
316
+ Frida Kahlo
317
+ Kahlo
318
+ Jaime Jones
319
+ Donald Judd
320
+ Kawase Hasui
321
+ Tim Okamura
322
+ Anton Otto Fischer
323
+ Tom Lovell
324
+ Richard Hamilton
325
+ Emiliano Ponzi
326
+ Charles Marion Russell
327
+ Ina Wong
328
+ Adam Paquette
329
+ Otto Dix
330
+ Gabriel Dawe
331
+ Mary Cassatt
332
+ Arkhip Kuindzhi
333
+ Jason Felix
334
+ Piranesi
335
+ Marianne North
336
+ Peter Lindbergh
337
+ Georges de La Tour
338
+ Francis Picabia
339
+ Kay Nielsen
340
+ Sanford Robinson Gifford
341
+ Hans Baluschek
342
+ Audrey Kawasaki
343
+ Mark Rothko
344
+ Frank Auerbach
345
+ Winston Churchill
346
+ Cynthia Sheppard
347
+ Chris Rahn
348
+ Todd Lockwood
349
+ Harry Clarke
350
+ Coby Whitmore
351
+ Margaret Keane
352
+ Man Ray
353
+ Hubert Robert
354
+ Dorothea Tanning
355
+ Ivan Bilibin
356
+ Austin Osman Spare
357
+ Paul Klee
358
+ Frederic Leighton
359
+ Alfonse Mucha
360
+ Fernando Botero
361
+ Marco Mazzoni
362
+ Evgeny Lushpin
363
+ John Atkinson Grimshaw
364
+ Peter Paul Rubens
365
+ Thomas Lawrence
366
+ Yasar Vurdem
367
+ Isaac Levitan
368
+ Asher Brown Durand
369
+ Yoann Lossel
370
+ Henry Ossawa Tanner
371
+ Bill Ward
372
+ Jean Arp
373
+ Jenny Saville
374
+ Katsushika Hokusai
375
+ Kim Keever
376
+ Pablo Picasso
377
+ Robert Delaunay
378
+ Delaunay
379
+ Chris Rallis
380
+ Oleg Oprisco
381
+ Anka Zhuravleva
382
+ Walt Disney
383
+ Tom Chambers
384
+ Salvador Dalí
385
+ Dalí
386
+ Edward Gorey
387
+ William Morris
388
+ Takeshi Obata
389
+ Juan Luna
390
+ Christophe Vacher
391
+ Grzegorz Rutkowski
392
+ Tamara de Lempicka
393
+ Tadao Ando
394
+ Peter Gric
395
+ sparth
396
+ Leonora Carrington
397
+ Mœbius
398
+ Constant
399
+ John Anster Fitzgerald
400
+ Patrick Nagel
CSD/data/laion.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ from torch.utils.data import Dataset
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ import pickle
7
+ import vaex as vx
8
+
9
+
10
+ def create_laion_cache(root_dir, anno_dir, keys=['artist', 'medium', 'movement']):
11
+ # -projects/diffusion_rep/data/laion_style_subset
12
+ # read all the picke files in the anno_dir
13
+ paths = []
14
+ labels = [] # list of lists since each image can have multiple labels
15
+ labels_to_index = {} # dictionary that maps each label to an list of image indices
16
+
17
+ keys_offset = {k: 1000000 * i for i, k in enumerate(keys)} # offset each key labels by a large number
18
+
19
+ str_to_list = lambda x, offset: [offset + int(a) for a in x.strip().split(',') if len(a) > 0]
20
+ for f in tqdm(os.listdir(anno_dir)):
21
+ if f.endswith('.pkl'):
22
+ with open(os.path.join(anno_dir, f), 'rb') as tmp:
23
+ ann = pickle.load(tmp)
24
+
25
+ for i, path in enumerate(ann['key']):
26
+ cur_label = {
27
+ k: str_to_list(ann[k][i], keys_offset[k])
28
+ for k in keys
29
+ }
30
+ cur_label = sum(cur_label.values(), [])
31
+ if len(cur_label) > 0:
32
+ image_path = os.path.join(root_dir, 'data', path[:5], path + '.jpg')
33
+
34
+ # if not os.path.exists(image_path):
35
+ # continue
36
+
37
+ paths.append(image_path)
38
+ labels.append(set(cur_label))
39
+ for l in cur_label: labels_to_index.setdefault(l, []).append(i)
40
+
41
+ cache_path = os.path.join(anno_dir, '_'.join(keys) + '.cache')
42
+ with open(cache_path, 'wb') as tmp:
43
+ pickle.dump((paths, labels, labels_to_index), tmp)
44
+ return paths, labels, labels_to_index
45
+
46
+
47
+ class LAION(Dataset):
48
+ def __init__(self, root_dir, anno_dir, split='database', transform=None,
49
+ keys=['artist', 'medium', 'movement'],
50
+ min_images_per_label=1, max_images_per_label=100000,
51
+ num_queries_per_label=10, maxsize=None, model_type='dann'):
52
+ # -projects/diffusion_rep/data/laion_style_subset
53
+ self.root_dir = root_dir
54
+ self.transform = transform
55
+ self.model_type = model_type
56
+
57
+ # read all the picke files in the anno_dir
58
+ paths = []
59
+ labels = [] # list of lists since each image can have multiple labels
60
+ labels_to_index = {} # dictionary that maps each label to an list of image indices
61
+
62
+ cache_path = os.path.join(anno_dir, '_'.join(keys) + '.cache')
63
+ if os.path.exists(cache_path):
64
+ with open(cache_path, 'rb') as tmp:
65
+ paths, labels, labels_to_index = pickle.load(tmp)
66
+ else:
67
+ paths, labels, labels_to_index = create_laion_cache(root_dir, anno_dir, keys)
68
+
69
+ maxout_labels = [l for l, v in labels_to_index.items() if len(v) > max_images_per_label]
70
+ maxout_labels.append('') # Artificially add an empty label
71
+ print(f"Removing {len(maxout_labels)} tags with > {max_images_per_label} images")
72
+
73
+ minout_labels = [l for l, v in labels_to_index.items() if len(v) < min_images_per_label]
74
+ print(f"Removing {len(minout_labels)} tags with < {min_images_per_label} images")
75
+
76
+ # Get all possible tags
77
+ self.index_to_labels = list(set(labels_to_index.keys()) - set(maxout_labels) - set(minout_labels))
78
+ self.labels_to_index = {l: i for i, l in enumerate(self.index_to_labels)}
79
+
80
+ self.pathlist = []
81
+ self.labels = []
82
+ eye = np.eye(len(self.index_to_labels))
83
+ print("Filtering out labels")
84
+ for path, label in tqdm(zip(paths, labels)):
85
+ for l in maxout_labels:
86
+ if l in label:
87
+ label.remove(l)
88
+
89
+ for l in minout_labels:
90
+ if l in label:
91
+ label.remove(l)
92
+
93
+ if len(label) > 0:
94
+ self.pathlist.append(path)
95
+ cur_label = np.sum(eye[[self.labels_to_index[l] for l in label]], axis=0).astype(bool)
96
+ self.labels.append(cur_label)
97
+ self.labels = np.array(self.labels)
98
+
99
+ ## Split the dataset into index and query
100
+ keys_offset = {k: 1000000 * i for i, k in enumerate(keys)}
101
+ self.name_to_label = {}
102
+ for k in keys:
103
+ key_labels_path = os.path.join(
104
+ anno_dir, '../clip-interrogator/clip_interrogator/data',
105
+ k + "s_filtered_new.txt")
106
+ with open(os.path.join(key_labels_path)) as f:
107
+ for i, l in enumerate(f.readlines()):
108
+ self.name_to_label[l.strip().replace("@", " ")] = keys_offset[k] + i
109
+
110
+ with open(os.path.join(anno_dir, 'top612_artists_shortlist_400.txt'), 'r') as f:
111
+ q_names = [l.lower().strip() for l in f.readlines()]
112
+ q_labels = [self.name_to_label[n] for n in q_names]
113
+ q_index = [self.labels_to_index[l] for l in q_labels]
114
+
115
+ query_ind = np.unique(np.concatenate(
116
+ [np.where(self.labels[:, i])[0][:num_queries_per_label]
117
+ for i in q_index]))
118
+
119
+ if split == "database":
120
+ self.pathlist = [self.pathlist[i] for i in range(len(self.pathlist)) if i not in query_ind]
121
+ self.labels = np.delete(self.labels, query_ind, axis=0)
122
+ else:
123
+ self.pathlist = [self.pathlist[i] for i in query_ind]
124
+ self.labels = self.labels[query_ind]
125
+
126
+ self.namelist = list(map(lambda x: x.split('/')[-1], self.pathlist))
127
+ # Select maxsize number of images
128
+ if maxsize is not None:
129
+ ind = np.random.randint(0, len(self.pathlist), maxsize)
130
+ self.pathlist = [self.pathlist[i] for i in ind]
131
+ self.labels = self.labels[ind]
132
+ self.namelist = [self.namelist[i] for i in ind]
133
+
134
+ def __len__(self):
135
+ return len(self.pathlist)
136
+
137
+ def __getitem__(self, idx):
138
+ img_loc = self.pathlist[idx]
139
+ image = Image.open(img_loc).convert("RGB")
140
+
141
+ if self.transform:
142
+ images = self.transform(image)
143
+
144
+ style = self.labels[idx]
145
+ if self.model_type == 'dann':
146
+ return images, style, idx
147
+ else:
148
+ return images, idx
149
+
150
+
151
+ def create_laion_dedup_cache(dedup_dir):
152
+ # -projects/diffusion_rep/data/laion_style_subset/dedup_info
153
+ keys = None
154
+ labels = None
155
+ rejects = None
156
+ matching_info = None
157
+
158
+ files = [f for f in os.listdir(dedup_dir) if f.endswith('.parquet')]
159
+ for f in tqdm(sorted(files, key=lambda x: int(x.split('_')[2]))):
160
+ # Load dedup info
161
+ df = vx.open(os.path.join(dedup_dir, f))
162
+ if keys is None:
163
+ keys = df['name'].tolist()
164
+
165
+ # Updating reject information
166
+ cur_reject = df['matched'].to_numpy()
167
+ if rejects is not None:
168
+ rejects += cur_reject
169
+ else:
170
+ rejects = cur_reject
171
+
172
+ # Load labels
173
+ cur_labels = np.load(os.path.join(dedup_dir, f.replace('parquet', 'npz').replace('val_db', 'multilabel')))
174
+ cur_labels = cur_labels["arr_0"]
175
+ if labels is not None:
176
+ labels += cur_labels
177
+ else:
178
+ labels = cur_labels
179
+
180
+ # Load matching info
181
+ cur_matching_info = pickle.load(
182
+ open(os.path.join(dedup_dir, f.replace('parquet', 'pkl').replace('val_db', 'matching_info')), 'rb'))
183
+ if matching_info is not None:
184
+ matching_info.extend(cur_matching_info)
185
+ else:
186
+ matching_info = cur_matching_info
187
+
188
+ # Propagating labels
189
+ for i in tqdm(range(len(matching_info) - 1, -1, -1)):
190
+ labels[i] += np.sum(labels[matching_info[i], :], axis=0, dtype=bool)
191
+
192
+ cache_path = os.path.join(dedup_dir, 'joined.cache')
193
+ with open(cache_path, 'wb') as tmp:
194
+ pickle.dump((keys, labels, rejects), tmp)
195
+ return keys, labels, rejects
196
+
197
+
198
+ class LAIONDedup(Dataset):
199
+ def __init__(self, root_dir, anno_dir, transform=None, model_type='dann', eval_mode=False, artist_mode=False):
200
+ self.root_dir = root_dir
201
+ self.transform = transform
202
+ self.model_type = model_type
203
+
204
+ dedup_dir = os.path.join(anno_dir, 'dedup_info')
205
+ cache_path = os.path.join(dedup_dir, 'joined.cache')
206
+ if os.path.exists(cache_path):
207
+ with open(cache_path, 'rb') as tmp:
208
+ keys, labels, rejects = pickle.load(tmp)
209
+ else:
210
+ keys, labels, rejects = create_laion_dedup_cache(dedup_dir)
211
+
212
+ keys = np.array(keys)[~rejects]
213
+ self.pathlist = [os.path.join(root_dir, 'data', key[:5], key + '.jpg') for key in keys]
214
+ self.labels = labels[~rejects]
215
+ self.namelist = list(map(lambda x: x.split('/')[-1], self.pathlist))
216
+
217
+ if eval_mode:
218
+ q_dset = LAION(root_dir, anno_dir, split='query')
219
+ self.query_db = vx.from_arrays(
220
+ name=[x.split('.')[0] for x in q_dset.namelist],
221
+ multilabel=q_dset.labels)
222
+
223
+ self.name_to_label = q_dset.name_to_label
224
+ self.labels_to_index = q_dset.labels_to_index
225
+ self.index_to_labels = q_dset.index_to_labels
226
+
227
+ self.val_db = vx.from_arrays(
228
+ name=keys.tolist(),
229
+ multilabel=self.labels)
230
+
231
+ if artist_mode:
232
+ # Filtering the db to include images which have hit on an artist
233
+ artist_inds = []
234
+ for label, index in self.labels_to_index.items():
235
+ if label < 1000000:
236
+ artist_inds.append(index)
237
+ artist_labels = self.labels[:, artist_inds]
238
+ artist_images = np.argwhere(np.sum(artist_labels, axis=1) > 0)
239
+ self.val_db = self.val_db.take(artist_images.squeeze()).extract()
240
+
241
+ def __len__(self):
242
+ return len(self.pathlist)
243
+
244
+ def __getitem__(self, idx):
245
+ img_loc = self.pathlist[idx]
246
+ image = Image.open(img_loc).convert("RGB")
247
+
248
+ if self.transform:
249
+ images = self.transform(image)
250
+
251
+ style = self.labels[idx]
252
+ if self.model_type == 'dann':
253
+ return images, style, idx
254
+ else:
255
+ return images, idx
256
+
257
+ def get_query_col(self, col):
258
+ return np.asarray(self.query_db[col].tolist())
259
+
260
+ def get_val_col(self, col):
261
+ return np.asarray(self.val_db[col].tolist())
262
+
263
+
264
+ class SDSynth400:
265
+ def __init__(self, root_dir, query_split='user_caps', transform=None, eval_mode=False):
266
+ self.root_dir = root_dir
267
+ self.transform = transform
268
+ self.query_split = query_split
269
+ assert query_split in ['user_caps', 'simple_caps', 'woman_caps', 'house_caps', 'dog_caps']
270
+ assert os.path.exists(os.path.join(root_dir, f'{query_split}.csv'))
271
+ annotations = vx.from_csv(f'{self.root_dir}/{query_split}.csv')
272
+
273
+ self.pathlist = annotations['filepath'].tolist()
274
+ self.namelist = list(map(lambda x: x.split('/')[-1], self.pathlist))
275
+
276
+ # Dummy variables, not actually needed
277
+ self.query_images = []
278
+ self.val_images = []
279
+
280
+ if eval_mode:
281
+ data_dir = '-datasets/improved_aesthetics_6plus'
282
+ anno_dir = '-projects/diffusion_rep/data/laion_style_subset'
283
+ val_dset = LAIONDedup(data_dir, anno_dir, transform=transform, eval_mode=True, artist_mode=True)
284
+ # val_dset = LAION(data_dir, anno_dir, transform=transform)
285
+ # Needed for search code
286
+ filenames = [f.split('.')[0] for f in self.namelist]
287
+ q_names = [[l.lower().strip() for l in eval(label)] for label in annotations['labels'].tolist()]
288
+ q_labels = [[val_dset.name_to_label[n] for n in names if n in val_dset.name_to_label] for names in q_names]
289
+ q_index = [[val_dset.labels_to_index[l] for l in labels if l in val_dset.labels_to_index] for labels in
290
+ q_labels]
291
+
292
+ eye = np.eye(len(val_dset.index_to_labels))
293
+ q_binlabels = [np.sum(eye[ind], axis=0).astype(bool) for ind in q_index]
294
+ self.query_db = vx.from_arrays(
295
+ name=filenames, multilabel=q_binlabels)
296
+ self.val_db = val_dset.val_db
297
+
298
+ def __len__(self):
299
+ return len(self.namelist)
300
+
301
+ def __getitem__(self, idx):
302
+ img_loc = self.pathlist[idx]
303
+ image = Image.open(img_loc).convert("RGB")
304
+ if self.transform:
305
+ image = self.transform(image)
306
+
307
+ return image, idx
308
+
309
+ def get_query_col(self, col):
310
+ return np.asarray(self.query_db[col].tolist())
311
+
312
+ def get_val_col(self, col):
313
+ return np.asarray(self.val_db[col].tolist())
314
+
315
+
316
+ if __name__ == "__main__":
317
+ # dset = WikiArt(
318
+ # "-projects/diffusion_rep/data/wikiart", 'database')
319
+
320
+ dset = LAION(
321
+ "-datasets/improved_aesthetics_6plus",
322
+ "-projects/diffusion_rep/data/laion_style_subset",
323
+ split='database')
324
+ print(f"{len(dset)} images in the dataset")
325
+
326
+ index_to_labels = []
327
+ index_to_keys = []
328
+ index_to_texts = []
329
+ label_to_name = {v: k for k, v in dset.name_to_label.items()}
330
+ for label in dset.index_to_labels:
331
+ index_to_texts.append(label_to_name[label])
332
+ index_to_labels.append(label)
333
+ if label < 1000000:
334
+ index_to_keys.append('artist')
335
+ elif label < 2000000:
336
+ index_to_keys.append('medium')
337
+ else:
338
+ index_to_keys.append('movement')
339
+
340
+ path = "-projects/diffusion_rep/data/laion_style_subset/index_to_labels_keys_texts.pkl"
341
+ with open(path, 'wb') as tmp:
342
+ pickle.dump((index_to_labels, index_to_keys, index_to_texts), tmp)
343
+
344
+ # dset = LAION(
345
+ # "-datasets/improved_aesthetics_6plus",
346
+ # "-projects/diffusion_rep/data/laion_style_subset",
347
+ # split='query',
348
+ # min_images_per_label=10,
349
+ # max_images_per_label=100000)
350
+
351
+ # print(f"{len(dset)} images in the dataset")
352
+
353
+ # dset = LAIONDedup(
354
+ # "-datasets/improved_aesthetics_6plus",
355
+ # "-projects/diffusion_rep/data/laion_style_subset",
356
+ # eval_mode=True)
CSD/data/wikiart.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import os
3
+ import sys
4
+ import os.path as osp
5
+ from PIL import Image
6
+ from torch.utils.data import Dataset
7
+ import pandas as pd
8
+ import vaex as vx
9
+ import numpy as np
10
+
11
+
12
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
13
+
14
+
15
+ class WikiArt(object):
16
+ def __init__(self, root_dir):
17
+ assert osp.exists(osp.join(root_dir, 'wikiart.csv'))
18
+ self.root_dir = root_dir
19
+ annotations = vx.from_csv(f'{self.root_dir}/wikiart.csv')
20
+ acceptable_artists = list(set(annotations[annotations['split'] == 'database']['artist'].tolist()))
21
+ temprepo = annotations[annotations['artist'].isin(acceptable_artists)]
22
+ self.query_images = temprepo[temprepo['split'] == 'query']['name'].tolist()
23
+ self.val_images = temprepo[temprepo['split'] == 'database']['name'].tolist()
24
+ self.query_db = annotations[annotations['name'].isin(self.query_images)]
25
+ self.val_db = annotations[annotations['name'].isin(self.val_images)]
26
+ self.query_db['name'] = self.query_db['name'].apply(lambda x: '.'.join(x.split('.')[:-1]))
27
+ self.val_db['name'] = self.val_db['name'].apply(lambda x: '.'.join(x.split('.')[:-1]))
28
+
29
+ def get_query_col(self, col):
30
+ return np.asarray(self.query_db[col].tolist())
31
+
32
+ def get_val_col(self, col):
33
+ return np.asarray(self.val_db[col].tolist())
34
+
35
+
36
+ class WikiArtD(Dataset):
37
+ def __init__(self, root_dir, split, transform=None):
38
+ self.root_dir = root_dir
39
+ self.transform = transform
40
+ self.split = split
41
+ assert osp.exists(osp.join(root_dir, 'wikiart.csv'))
42
+ annotations = vx.from_csv(f'{self.root_dir}/wikiart.csv')
43
+ acceptable_artists = list(set(annotations[annotations['split'] == 'database']['artist'].tolist()))
44
+ temprepo = annotations[annotations['artist'].isin(acceptable_artists)]
45
+ self.pathlist = temprepo[temprepo['split'] == split]['path'].tolist()
46
+
47
+ self.namelist = list(map(lambda x: x.split('/')[-1], self.pathlist))
48
+
49
+ def __len__(self):
50
+ return len(self.namelist)
51
+
52
+ def __getitem__(self, idx):
53
+ img_loc = self.pathlist[idx] # os.path.join(self.root_dir, self.split,self.artists[idx] ,self.pathlist[idx])
54
+ image = Image.open(img_loc).convert("RGB")
55
+ if self.transform:
56
+ image = self.transform(image)
57
+
58
+ return image, idx
59
+
60
+
61
+ class WikiArtTrain(Dataset):
62
+ def __init__(self, root_dir, split='database', transform=None, maxsize=None):
63
+ self.root_dir = root_dir
64
+ self.transform = transform
65
+ self.split = split
66
+ assert os.path.exists(os.path.join(root_dir, 'wikiart.csv'))
67
+ annotations = pd.read_csv(f'{self.root_dir}/wikiart.csv')
68
+ acceptable_artists = list(
69
+ set(annotations[annotations['split'] == 'database']['artist'].tolist())
70
+ )
71
+ temprepo = annotations[annotations['artist'].isin(acceptable_artists)]
72
+ self.pathlist = temprepo[temprepo['split'] == split]['path'].tolist()
73
+ self.labels = temprepo[temprepo['split'] == split]['artist'].tolist()
74
+
75
+ self.artist_to_index = {artist: i for i, artist in enumerate(acceptable_artists)}
76
+ self.index_to_artist = acceptable_artists
77
+
78
+ # Convert labels to one-hot
79
+ self.labels = list(map(lambda x: self.artist_to_index[x], self.labels))
80
+ self.labels = np.eye(len(acceptable_artists))[self.labels].astype(bool)
81
+ self.namelist = list(map(lambda x: x.split('/')[-1], self.pathlist))
82
+
83
+ # Select maxsize number of images
84
+ if maxsize is not None:
85
+ ind = np.random.randint(0, len(self.namelist), maxsize)
86
+ self.namelist = [self.namelist[i] for i in ind]
87
+ self.pathlist = [self.pathlist[i] for i in ind]
88
+ self.labels = self.labels[ind]
89
+
90
+ def __len__(self):
91
+ return len(self.namelist)
92
+
93
+ def __getitem__(self, idx):
94
+
95
+ img_loc = self.pathlist[idx]
96
+ image = Image.open(img_loc).convert("RGB")
97
+
98
+ if self.transform:
99
+ images = self.transform(image)
100
+
101
+ artist = self.labels[idx]
102
+ return images, artist, idx
CSD/embeddings/.gitkeep ADDED
File without changes
CSD/environment.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: style
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ - conda-forge
6
+ dependencies:
7
+ - pillow
8
+ - pip
9
+ - python=3.9
10
+ - pytorch=*=*cuda11.3*
11
+ - cudatoolkit>=11.3
12
+ - scipy
13
+ - torchvision
14
+ - jupyterlab
15
+ - ipywidgets
16
+ - scikit-image
17
+ - faiss-gpu
18
+ - tensorboard
19
+ - pip:
20
+ - git+https://github.com/openai/CLIP.git
21
+ - pandas
22
+ - ipdb
23
+ - wandb
24
+ - timm==0.6.12
25
+ - matplotlib
26
+ - einops
27
+ - vaex
28
+ - seaborn
29
+ - scikit-learn
CSD/github_teaser.jpg ADDED
CSD/laion-styles-subset-tags.txt ADDED
@@ -0,0 +1,3480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a j casson
2
+ aaron bohrod
3
+ aaron douglas
4
+ aaron jasinski
5
+ aaron miller
6
+ aaron nagel
7
+ abbott handerson thayer
8
+ abdur rahman chughtai
9
+ abraham bloemaert
10
+ abraham de vries
11
+ abraham hondius
12
+ abraham mignon
13
+ abraham storck
14
+ abraham van den tempel
15
+ abraham willaerts
16
+ abram arkhipov
17
+ adam bruce thomson
18
+ adam elsheimer
19
+ adam paquette
20
+ adam rex
21
+ adolf dietrich
22
+ adolf schrödter
23
+ adolf ulric wertmüller
24
+ adolph gottlieb
25
+ adolph menzel
26
+ adriaen brouwer
27
+ adriaen coorte
28
+ adriaen hanneman
29
+ adriaen isenbrant
30
+ adriaen van de velde
31
+ adriaen van de venne
32
+ adriaen van der werff
33
+ adriaen van ostade
34
+ adrian ghenie
35
+ adrian zingg
36
+ aelbert cuyp
37
+ aert de gelder
38
+ aert van der neer
39
+ afewerk tekle
40
+ affandi
41
+ agnes martin
42
+ agnolo bronzino
43
+ agnolo gaddi
44
+ agostino carracci
45
+ ai weiwei
46
+ ai xuan
47
+ aimé barraud
48
+ akihiko yoshida
49
+ akira toriyama
50
+ al feldstein
51
+ al williamson
52
+ alan bean
53
+ alan davis
54
+ alan lee
55
+ alasdair gray
56
+ albert anker
57
+ albert aublet
58
+ albert bierstadt
59
+ albert dorne
60
+ albert edelfelt
61
+ albert gleizes
62
+ albert guillaume
63
+ albert joseph moore
64
+ albert marquet
65
+ albert namatjira
66
+ alberto giacometti
67
+ alberto morrocco
68
+ alberto seveso
69
+ alberto vargas
70
+ albrecht altdorfer
71
+ albrecht dürer
72
+ alejandro burdisio
73
+ aleksander gierymski
74
+ aleksander kobzdej
75
+ aleksandr gerasimov
76
+ aleksi briclot
77
+ alena aenami
78
+ alessandro allori
79
+ alesso baldovinetti
80
+ alex grey
81
+ alex horley
82
+ alex katz
83
+ alex ross
84
+ alex toth
85
+ alexander archipenko
86
+ alexander brook
87
+ alexander calder
88
+ alexander carse
89
+ alexander deyneka
90
+ alexander ivanov
91
+ alexander jansson
92
+ alexander johnston
93
+ alexander kanoldt
94
+ alexander kucharsky
95
+ alexander litovchenko
96
+ alexander mann
97
+ alexander mcqueen
98
+ alexander nasmyth
99
+ alexander robertson
100
+ alexander rodchenko
101
+ alexander roslin
102
+ alexander scott
103
+ alexander sharpe ross
104
+ alexander stirling calder
105
+ alexandre benois
106
+ alexandre cabanel
107
+ alexandre falguière
108
+ alexei kondratyevich savrasov
109
+ alexej von jawlensky
110
+ alexey venetsianov
111
+ alexis grimou
112
+ alexis simon belle
113
+ alfons walde
114
+ alfonse mucha
115
+ alfred east
116
+ alfred edward chalon
117
+ alfred eisenstaedt
118
+ alfred freddy krupa
119
+ alfred janes
120
+ alfred jensen
121
+ alfred kubin
122
+ alfred leslie
123
+ alfred leyman
124
+ alfred richard gurrey
125
+ alfred sisley
126
+ alfred thompson bricher
127
+ algernon talmage
128
+ alice bailly
129
+ alice mason
130
+ alice neel
131
+ alice prin
132
+ alison watt
133
+ allaert van everdingen
134
+ allan brooks
135
+ allan linder
136
+ allan ramsay
137
+ allen butler talcott
138
+ allen jones
139
+ allen tupper true
140
+ alma thomas
141
+ almada negreiros
142
+ almeida júnior
143
+ aloysius okelly
144
+ alphonse legros
145
+ alphonse mucha
146
+ alvan fisher
147
+ amadeo de souza cardoso
148
+ amalia lindegren
149
+ amanda sage
150
+ amano
151
+ ambrose mccarthy patterson
152
+ ambrosius benson
153
+ ambrosius bosschaert
154
+ ambrosius holbein
155
+ amedeo modigliani
156
+ americo makk
157
+ amir zand
158
+ ammi phillips
159
+ amos sewell
160
+ amy sol
161
+ amy weber
162
+ an gyeon
163
+ anato finnstark
164
+ anders zorn
165
+ andré charles biéler
166
+ andre derain
167
+ andré derain
168
+ andré françois
169
+ andré kertész
170
+ andré lhote
171
+ andré masson
172
+ andrea del sarto
173
+ andrea del verrocchio
174
+ andrea kowch
175
+ andrea mantegna
176
+ andrea orcagna
177
+ andrea pozzo
178
+ andreas achenbach
179
+ andreas gursky
180
+ andreas rocha
181
+ andrei riabovitchev
182
+ andrei rublev
183
+ andrei ryabushkin
184
+ andrew bell
185
+ andrew ferez
186
+ andrew geddes
187
+ andrew henderson
188
+ andrew law
189
+ andrew loomis
190
+ andrew robertson
191
+ andrew robinson
192
+ andrew stevovich
193
+ andrew wyeth
194
+ android jones
195
+ andrzej wróblewski
196
+ andy goldsworthy
197
+ andy warhol
198
+ ángel botello
199
+ angelica kauffman
200
+ aniello falcone
201
+ anish kapoor
202
+ anita kunz
203
+ anka zhuravleva
204
+ anna ancher
205
+ anna boch
206
+ anna dittmann
207
+ anna mary robertson moses
208
+ anne dunn
209
+ anne geddes
210
+ anne redpath
211
+ anne ryan
212
+ anne said
213
+ anne savage
214
+ anne stokes
215
+ anni albers
216
+ annibale carracci
217
+ annie leibovitz
218
+ annie rose laing
219
+ ansel adams
220
+ anselm kiefer
221
+ antanas sutkus
222
+ anthony devas
223
+ anthony palumbo
224
+ anthony van dyck
225
+ antoine blanchard
226
+ antoine ignace melling
227
+ antoine le nain
228
+ antoine wiertz
229
+ anton ažbe
230
+ anton fadeev
231
+ anton graff
232
+ anton mauve
233
+ anton otto fischer
234
+ anton pieck
235
+ anton räderscheidt
236
+ antonello da messina
237
+ antoni brodowski
238
+ antonin artaud
239
+ antonín chittussi
240
+ antonín slavíček
241
+ antonio canova
242
+ antonio cavallucci
243
+ antonio ciseri
244
+ antonio de la gandara
245
+ antonio donghi
246
+ antonio mancini
247
+ antônio parreiras
248
+ antonio rotta
249
+ apelles
250
+ apollinary vasnetsov
251
+ apollonia saintclair
252
+ archibald motley
253
+ archibald robertson
254
+ archibald skirving
255
+ archibald standish hartrick
256
+ arcimboldo
257
+ arie smit
258
+ aristide maillol
259
+ arkhip kuindzhi
260
+ armand guillaumin
261
+ armin hansen
262
+ arnold blanch
263
+ arnold böcklin
264
+ arnold mesches
265
+ arnold newman
266
+ arshile gorky
267
+ art fitzpatrick
268
+ art frahm
269
+ art green
270
+ art spiegelman
271
+ artemisia gentileschi
272
+ artgerm
273
+ arthur adams
274
+ arthur b carles
275
+ arthur boyd
276
+ arthur burdett frost
277
+ arthur dove
278
+ arthur hughes
279
+ arthur lismer
280
+ arthur melville
281
+ arthur pan
282
+ arthur quartley
283
+ arthur rackham
284
+ arthur sarkissian
285
+ arthur streeton
286
+ artur grottger
287
+ arvid nyholm
288
+ asaf hanuka
289
+ asai chū
290
+ asher brown durand
291
+ ashley wood
292
+ atey ghailan
293
+ attila meszlenyi
294
+ aubrey beardsley
295
+ audrey kawasaki
296
+ august friedrich schenck
297
+ august macke
298
+ august sander
299
+ auguste herbin
300
+ augustus earle
301
+ augustus john
302
+ augustus vincent tack
303
+ auseklis ozols
304
+ austin briggs
305
+ austin english
306
+ austin osman spare
307
+ ayo
308
+ ayami kojima
309
+ balthasar van der ast
310
+ balthus
311
+ banksy
312
+ barbara longhi
313
+ barclay shaw
314
+ barent fabritius
315
+ barkley hendricks
316
+ barnett newman
317
+ barron storey
318
+ bartholomeus breenbergh
319
+ bartholomeus strobel
320
+ bartholomeus van bassen
321
+ bartholomeus van der helst
322
+ bartolomé esteban murillo
323
+ bartolomeo vivarini
324
+ bascove
325
+ basil blackshaw
326
+ basuki abdullah
327
+ bayard wu
328
+ beatrice ethel lithiby
329
+ beatrix potter
330
+ beauford delaney
331
+ beeple
332
+ beksinski
333
+ ben enwonwu
334
+ ben nicholson
335
+ ben shahn
336
+ ben stahl
337
+ ben templesmith
338
+ ben thompson
339
+ benito quinquela martín
340
+ benjamin block
341
+ benjamin franklin
342
+ benjamin gerritsz cuyp
343
+ benjamin west
344
+ benjamin williams leader
345
+ benoit b mandelbrot
346
+ bernard buffet
347
+ bernard meninsky
348
+ bernard van orley
349
+ bernardo bellotto
350
+ bernardo cavallino
351
+ bernardo daddi
352
+ bernardo strozzi
353
+ bernie wrightson
354
+ bert hardy
355
+ 3d render
356
+ black white photo
357
+ bert stern
358
+ cartoon
359
+ berthe morisot
360
+ character portrait
361
+ charcoal drawing
362
+ bronze sculpture
363
+ bertalan székely
364
+ colorized photo
365
+ cave painting
366
+ color pencil sketch
367
+ cross stitch
368
+ cubist painting
369
+ detailed drawing
370
+ bhupen khakhar
371
+ detailed painting
372
+ diagram
373
+ digital painting
374
+ digital rendering
375
+ drawing
376
+ fine art painting
377
+ bill ward
378
+ gouache
379
+ billy childish
380
+ hyperrealistic painting
381
+ jigsaw puzzle
382
+ bob byerley
383
+ bob eggleton
384
+ bob ross
385
+ flemish baroque
386
+ matte painting
387
+ hologram
388
+ macro photograph
389
+ manga drawing
390
+ mosaic
391
+ low poly render
392
+ pastel
393
+ pencil sketch
394
+ boris kustodiev
395
+ microscopic photo
396
+ bourgeois
397
+ photorealistic painting
398
+ pointillism painting
399
+ photocopy
400
+ brad kunkle
401
+ pop art painting
402
+ brassaï
403
+ polaroid photo
404
+ renaissance painting
405
+ screenprint
406
+ screenshot
407
+ silk screen
408
+ sketch
409
+ statue
410
+ still life
411
+ stipple
412
+ brett whiteley
413
+ brian froud
414
+ surrealist painting
415
+ storybook illustration
416
+ tattoo
417
+ tilt shift photo
418
+ watercolor painting
419
+ surrealist sculpture
420
+ woodcut
421
+ abstract drawing
422
+ brom
423
+ abstract painting
424
+ acrylic painting
425
+ brothers hildebrandt
426
+ brooke shaden
427
+ album cover
428
+ bruce davidson
429
+ bruce gilden
430
+ airbrush painting
431
+ bruce pennington
432
+ etching
433
+ bruno liljefors
434
+ impressionist painting
435
+ ink drawing
436
+ oil canvas painting
437
+ buckminster fuller
438
+ art deco painting
439
+ chalk art
440
+ computer graphics
441
+ concept art
442
+ cyberpunk art
443
+ cagnaccio di san pietro
444
+ egyptian art
445
+ graffiti art
446
+ camille corot
447
+ lineart
448
+ camille pissarro
449
+ poster art
450
+ vector art
451
+ camille bombois
452
+ canaletto
453
+ camilo egas
454
+ camilo mori
455
+ candido portinari
456
+ cao zhibai
457
+ caravaggio
458
+ carel fabritius
459
+ carel weight
460
+ carel willink
461
+ carl barks
462
+ carl eytel
463
+ carl frederik von breda
464
+ carl gustaf pilo
465
+ carl heinrich bloch
466
+ carl hoppe
467
+ carl larsson
468
+ carl rahl
469
+ carl spitzweg
470
+ carl walter liner
471
+ carla wyzgala
472
+ carlo carrà
473
+ carlo crivelli
474
+ carlo galli bibiena
475
+ carlo mense
476
+ carlos schwabe
477
+ carne griffiths
478
+ carol bove
479
+ carol sutton
480
+ caroline lucy scott
481
+ caroline mytinger
482
+ carrie mae weems
483
+ casey baugh
484
+ caspar david friedrich
485
+ caspar netscher
486
+ caspar van wittel
487
+ caspar wolf
488
+ cassandra austen
489
+ cassius marcellus coolidge
490
+ catrin welzstein
491
+ cecil beaton
492
+ cecilia beaux
493
+ cecily brown
494
+ cedric peyravernay
495
+ cerith wyn evans
496
+ cézanne
497
+ chagall
498
+ chaim soutine
499
+ chaïm soutine
500
+ charles alston
501
+ charles angrand
502
+ charles bird king
503
+ charles blackman
504
+ charles codman
505
+ charles conder
506
+ charles cundall
507
+ charles dana gibson
508
+ charles demuth
509
+ charles e burchfield
510
+ charles furneaux
511
+ charles ginner
512
+ charles gleyre
513
+ charles h woodbury
514
+ charles harold davis
515
+ charles haslewood shannon
516
+ charles hopkinson
517
+ charles joshua chaplin
518
+ charles le brun
519
+ charles mahoney
520
+ charles marion russell
521
+ charles martin
522
+ charles mcauley
523
+ charles ragland bunnell
524
+ charles rennie mackintosh
525
+ charles ricketts
526
+ charles roka
527
+ charles schulz
528
+ charles thomson
529
+ charles vess
530
+ charles w bartlett
531
+ charles williams
532
+ charles willson peale
533
+ charlie bowater
534
+ charlotte harding
535
+ charlotte nasmyth
536
+ chase stone
537
+ chen chi
538
+ chen chun
539
+ chen daofu
540
+ chen hong
541
+ chen hongshou
542
+ chen lin
543
+ chen lu
544
+ chen yifei
545
+ cheng shifa
546
+ cheng zhengkui
547
+ chesley bonestell
548
+ chiharu shiota
549
+ childe hassam
550
+ chip zdarsky
551
+ chippy
552
+ choi buk
553
+ chris cold
554
+ chris foss
555
+ chris friel
556
+ chris labrooy
557
+ chris moore
558
+ chris rahn
559
+ chris rallis
560
+ chris ware
561
+ christen dalsgaard
562
+ christen købke
563
+ christian jane fergusson
564
+ christian krohg
565
+ christian rohlfs
566
+ christo
567
+ christoffer wilhelm eckersberg
568
+ christoph amberger
569
+ christoph ludwig agricola
570
+ christophe vacher
571
+ christopher balaskas
572
+ christopher moeller
573
+ christopher perkins
574
+ christopher williams
575
+ christopher wood
576
+ christopher wren
577
+ chuck close
578
+ cicely mary barker
579
+ cimabue
580
+ cindy sherman
581
+ cindy wright
582
+ claire dalby
583
+ claire hummel
584
+ clara miller burd
585
+ clara peeters
586
+ clara weaver parrish
587
+ clarence holbrook carter
588
+ clarice beckett
589
+ clark voorhees
590
+ claude cahun
591
+ claude lorrain
592
+ claude monet
593
+ cleon peterson
594
+ cleve gray
595
+ cliff childs
596
+ clifford ross
597
+ clint cearley
598
+ clyde caldwell
599
+ clyfford still
600
+ coby whitmore
601
+ coles phillips
602
+ colijn de coter
603
+ colin campbell cooper
604
+ colin gill
605
+ colin hayes
606
+ colin mccahon
607
+ colin middleton
608
+ colin moss
609
+ conrad roset
610
+ conroy maddox
611
+ constant
612
+ constant permeke
613
+ constantin hansen
614
+ corneille
615
+ cornelia parker
616
+ cornelis anthonisz
617
+ cornelis bisschop
618
+ cornelis de heem
619
+ cornelis de man
620
+ cornelis dusart
621
+ cornelis saftleven
622
+ cornelis van haarlem
623
+ correggio
624
+ cosmo alexander
625
+ craig davison
626
+ craig mullins
627
+ craig thompson
628
+ craola
629
+ cristofano allori
630
+ csaba markus
631
+ cuno amiet
632
+ cy twombly
633
+ cynthia sheppard
634
+ cyril rolando
635
+ d howard hitchcock
636
+ daarken
637
+ dai jin
638
+ dai xi
639
+ dali
640
+ dalí
641
+ damien hirst
642
+ dan frazier
643
+ dan hillier
644
+ dan luvisi
645
+ dan mumford
646
+ dan scott
647
+ dan smith
648
+ daniel f gerhartz
649
+ daniel garber
650
+ daniel lieske
651
+ daniel ljunggren
652
+ daniel maclise
653
+ daniel merriam
654
+ daniël mijtens
655
+ daniel taylor
656
+ dante gabriel rossetti
657
+ daphne fedarb
658
+ darek zabrocki
659
+ daren bader
660
+ dariusz zawadzki
661
+ dave arredondo
662
+ dave dorman
663
+ dave gibbons
664
+ dave kendall
665
+ dave mckean
666
+ david alfaro siqueiros
667
+ david allan
668
+ david annand
669
+ david bailly
670
+ david bomberg
671
+ david boyd
672
+ david brewster
673
+ david budd
674
+ david burliuk
675
+ david chipperfield
676
+ david diao
677
+ david donaldson
678
+ david eugene henry
679
+ david gilmour blythe
680
+ david hockney
681
+ david inshaw
682
+ david lachapelle
683
+ david ligare
684
+ david martin
685
+ david octavius hill
686
+ david palumbo
687
+ david park
688
+ david roberts
689
+ david simpson
690
+ david small
691
+ david teniers iii
692
+ david wilkie
693
+ david wojnarowicz
694
+ david young cameron
695
+ dean cornwell
696
+ dean ellis
697
+ dean roger
698
+ delaunay
699
+ delphin enjolras
700
+ dennis flanders
701
+ dennis miller bunker
702
+ derek gores
703
+ derek hill
704
+ derek jarman
705
+ derf
706
+ desmond morris
707
+ diane arbus
708
+ diane dillon
709
+ diego giacometti
710
+ diego gisbert llorens
711
+ diego rivera
712
+ diego velázquez
713
+ dieric bouts
714
+ ding guanpeng
715
+ ding yunpeng
716
+ dino valls
717
+ dionisio baixeras verdaguer
718
+ dirck de bray
719
+ dirck hals
720
+ dirck van baburen
721
+ dirck van delen
722
+ disney
723
+ ditlev blunck
724
+ dmitry levitzky
725
+ dod procter
726
+ domenichino
727
+ domenico di pace beccafumi
728
+ domenico ghirlandaio
729
+ domenico induno
730
+ domenico zampieri
731
+ don eddy
732
+ donald judd
733
+ donald roller wilson
734
+ donato giancola
735
+ dong kingman
736
+ dong qichang
737
+ dong yuan
738
+ dora carrington
739
+ dora maar
740
+ dorothea lange
741
+ dorothea tanning
742
+ dorothy burroughes
743
+ dorothy hood
744
+ dorothy johnstone
745
+ dorothy king
746
+ dosso dossi
747
+ douglas shuler
748
+ dr atl
749
+ dr seuss
750
+ drew struzan
751
+ drew tucker
752
+ du jin
753
+ duccio
754
+ dugald sutherland maccoll
755
+ abstract art
756
+ abstract expressionism
757
+ dürer
758
+ academic art
759
+ action painting
760
+ aestheticism
761
+ afrofuturism
762
+ duncan grant
763
+ dwight william tryon
764
+ american impressionism
765
+ american realism
766
+ american romanticism
767
+ american scene painting
768
+ earle bergey
769
+ e charlton fortune
770
+ arabesque
771
+ ed benedict
772
+ ed binkley
773
+ art brut
774
+ art deco
775
+ ed roth
776
+ art nouveau
777
+ art photography
778
+ eddie mendoza
779
+ edgar degas
780
+ arts crafts movement
781
+ ashcan school
782
+ assemblage
783
+ eddie campbell
784
+ edith lawrence
785
+ barbizon school
786
+ baroque
787
+ bauhaus
788
+ edmund blampied
789
+ edmund charles tarbell
790
+ edmund dulac
791
+ brutalism
792
+ classical realism
793
+ edmund leighton
794
+ cobra
795
+ color field
796
+ computer art
797
+ conceptual art
798
+ édouard manet
799
+ constructivism
800
+ concrete art
801
+ crayon art
802
+ eduardo kingman
803
+ cubism
804
+ eduard von grützner
805
+ edvard munch
806
+ dada
807
+ edward armitage
808
+ edward atkinson hornel
809
+ de stijl
810
+ edward arthur walton
811
+ digital art
812
+ deconstructivism
813
+ environmental art
814
+ edward clark
815
+ expressionism
816
+ fantastic realism
817
+ fantasy art
818
+ fauvism
819
+ edward henry potthast
820
+ edward gorey
821
+ edward hopper
822
+ figurative art
823
+ fine art
824
+ edward lamson henry
825
+ folk art
826
+ edward lear
827
+ edward mitchell bannister
828
+ futurism
829
+ furry art
830
+ edward robert hughes
831
+ figurativism
832
+ edward simmons
833
+ graffiti
834
+ gothic art
835
+ edward weston
836
+ happening
837
+ harlem renaissance
838
+ edwin deakin
839
+ holography
840
+ edwin austin abbey
841
+ edward willis redfield
842
+ hyperrealism
843
+ hudson river school
844
+ edwin georgi
845
+ edwin landseer
846
+ impressionism
847
+ eero järnefelt
848
+ egon schiele
849
+ egbert van der poel
850
+ eiq
851
+ interactive art
852
+ land art
853
+ kinetic art
854
+ les nabis
855
+ egbert van heemskerck
856
+ light space
857
+ lowbrow
858
+ ejnar nielsen
859
+ el greco
860
+ magic realism
861
+ magical realism
862
+ mail art
863
+ mannerism
864
+ el lissitzky
865
+ maximalism
866
+ metaphysical painting
867
+ lyrical abstraction
868
+ minimalism
869
+ elaine de kooning
870
+ modernism
871
+ eleanor fortescuebrickdale
872
+ naive art
873
+ naturalism
874
+ mingei
875
+ eleanor vere boyle
876
+ eliot hodgkin
877
+ élisabeth vigée le brun
878
+ eliseu visconti
879
+ neoclassicism
880
+ neogeo
881
+ elizabeth forbes
882
+ elizabeth jane lloyd
883
+ net art
884
+ new objectivity
885
+ elizabeth murray
886
+ elizabeth shippen green
887
+ new sculpture
888
+ elke vogelsang
889
+ op art
890
+ optical illusion
891
+ elliott erwitt
892
+ orphism
893
+ elmer bischoff
894
+ photorealism
895
+ pixel art
896
+ ellsworth kelly
897
+ plein air
898
+ pointillism
899
+ pop art
900
+ pop surrealism
901
+ elsa beskow
902
+ postimpressionism
903
+ elmyr de hory
904
+ precisionism
905
+ emanuel leutze
906
+ emanuel de witte
907
+ process art
908
+ psychedelic art
909
+ emil bisttram
910
+ emil carlsen
911
+ primitivism
912
+ emil fuchs
913
+ emil nolde
914
+ realism
915
+ regionalism
916
+ émile bernard
917
+ renaissance
918
+ retrofuturism
919
+ rococo
920
+ romanesque
921
+ emily carr
922
+ romanticism
923
+ emiliano ponzi
924
+ shin hanga
925
+ emiliano di cavalcanti
926
+ socialist realism
927
+ emily shanks
928
+ space art
929
+ street art
930
+ emory douglas
931
+ emma lampert cooper
932
+ superflat
933
+ suprematism
934
+ surrealism
935
+ symbolism
936
+ enrique simonet
937
+ enrique grau
938
+ enki bilal
939
+ tachisme
940
+ temporary art
941
+ tonalism
942
+ eric auld
943
+ eric deschamps
944
+ ukiyoe
945
+ eric peterson
946
+ eric taylor
947
+ eric zener
948
+ vanitas
949
+ erich heckel
950
+ video art
951
+ erin hanson
952
+ visual art
953
+ ernest biéler
954
+ underground comix
955
+ ernest buckmaster
956
+ ernest hébert
957
+ ernest lawson
958
+ ernest morgan
959
+ ernest procter
960
+ ernest william christmas
961
+ ernie barnes
962
+ ernst
963
+ ernst fuchs
964
+ ernst haeckel
965
+ ernst ludwig kirchner
966
+ ernst thoms
967
+ ernst wilhelm nay
968
+ erwin bowien
969
+ esaias van de velde
970
+ esao
971
+ esao andrews
972
+ esteban vicente
973
+ etienne delessert
974
+ ettore tito
975
+ euan uglow
976
+ eugène boudin
977
+ eugène burnand
978
+ eugène carrière
979
+ eugene delacroix
980
+ eugène delacroix
981
+ eugène grasset
982
+ eugène isabey
983
+ childs drawing
984
+ eugene von guerard
985
+ eugeniusz zak
986
+ eva gonzalès
987
+ évariste vital luminais
988
+ evaristo baschenis
989
+ evelyn abelson
990
+ evelyn de morgan
991
+ everett raymond kinstler
992
+ everett shinn
993
+ computer rendering
994
+ evert collier
995
+ evgeny lushpin
996
+ eyvind earle
997
+ f scott hess
998
+ fabien charuau
999
+ fairfield porter
1000
+ fan kuan
1001
+ fan qi
1002
+ fang congyi
1003
+ farel dalrymple
1004
+ fede galizia
1005
+ federico barocci
1006
+ federico uribe
1007
+ federico zandomeneghi
1008
+ federico zuccari
1009
+ fedot sychkov
1010
+ detailed matte painting
1011
+ felice casorati
1012
+ felicity charlton
1013
+ fei danxu
1014
+ félix vallotton
1015
+ félix ziem
1016
+ feng zhu
1017
+ fenghua zhong
1018
+ ferdinand bol
1019
+ ferdinand hodler
1020
+ ferdinand knab
1021
+ ferdynand ruszczyc
1022
+ fern coppedge
1023
+ fernand léger
1024
+ fernand pelez
1025
+ fernand toussaint
1026
+ fernando amorsolo
1027
+ fernando botero
1028
+ filip hodas
1029
+ filippino lippi
1030
+ fiona stephenson
1031
+ fitz henry lane
1032
+ fitz hugh lane
1033
+ fletcher martin
1034
+ flora macdonald reid
1035
+ floris van dyck
1036
+ floris van schooten
1037
+ ford madox brown
1038
+ fra angelico
1039
+ fra bartolomeo
1040
+ fra filippo lippi
1041
+ frances c fairman
1042
+ frances hodgkins
1043
+ frances macdonald
1044
+ francesco albani
1045
+ francesco bartolozzi
1046
+ francesco bonsignori
1047
+ francesco clemente
1048
+ francesco del cossa
1049
+ francesco filippini
1050
+ francesco guardi
1051
+ francesco hayez
1052
+ francesco raibolini
1053
+ francis bacon
1054
+ francis bourgeois
1055
+ francis cadell
1056
+ francis davis millet
1057
+ francis ernest jackson
1058
+ francis focer brown
1059
+ francis helps
1060
+ francis picabia
1061
+ marble sculpture
1062
+ francisco de holanda
1063
+ francisco de zurbarán
1064
+ francisco goya
1065
+ francisco oller
1066
+ francisco zúñiga
1067
+ franciszek smuglewicz
1068
+ françois barraud
1069
+ françois bocion
1070
+ françois boucher
1071
+ françois clouet
1072
+ françois joseph heim
1073
+ françois quesnel
1074
+ frank auerbach
1075
+ minimalist painting
1076
+ frank buchser
1077
+ frank dumond
1078
+ frank frazetta
1079
+ frank leonard brooks
1080
+ frank mason
1081
+ frank mckelvey
1082
+ frank miller
1083
+ frank montague moore
1084
+ frank omeara
1085
+ frank schoonover
1086
+ frank stella
1087
+ frank weston benson
1088
+ frank xavier leyendecker
1089
+ franklin booth
1090
+ franklin carmichael
1091
+ frans hals
1092
+ frans koppelaar
1093
+ frans masereel
1094
+ františek kaván
1095
+ františek kupka
1096
+ franz kline
1097
+ franz marc
1098
+ franz sedlacek
1099
+ franz stuck
1100
+ franz vohwinkel
1101
+ franz von lenbach
1102
+ franz xaver winterhalter
1103
+ fred cress
1104
+ fred ludekens
1105
+ fred mitchell
1106
+ fred williams
1107
+ frédéric bazille
1108
+ frederic church
1109
+ frederic edwin church
1110
+ frederic leighton
1111
+ frederic remington
1112
+ frederick carl frieseke
1113
+ frederick edwin church
1114
+ frederick goodall
1115
+ frederick hammersley
1116
+ frederick lord leighton
1117
+ frederick mccubbin
1118
+ frederik de moucheron
1119
+ frederik vermehren
1120
+ frida kahlo
1121
+ friedel dzubas
1122
+ friedensreich hundertwasser
1123
+ friedrich gauermann
1124
+ friedrich von amerling
1125
+ frieke janssens
1126
+ frits thaulow
1127
+ fritz von dardel
1128
+ fritz von uhde
1129
+ fu baoshi
1130
+ fujishima takeji
1131
+ fyodor alekseyev
1132
+ fyodor rokotov
1133
+ fyodor vasilyev
1134
+ gabriel ba
1135
+ gabriel dawe
1136
+ gabriel metsu
1137
+ gabriele münter
1138
+ gaetano previati
1139
+ gai qi
1140
+ galen dara
1141
+ gao cen
1142
+ gao fenghan
1143
+ garry winogrand
1144
+ gaston anglade
1145
+ gaston bussiere
1146
+ gaston bussière
1147
+ gaudi
1148
+ gaugin
1149
+ gavin hamilton
1150
+ gawen hamilton
1151
+ gediminas pranckevicius
1152
+ geertgen tot sint jans
1153
+ gen paul
1154
+ gene davis
1155
+ gentile bellini
1156
+ geof darrow
1157
+ geoffrey dyer
1158
+ georg baselitz
1159
+ georg friedrich kersting
1160
+ georg friedrich schmidt
1161
+ georg muche
1162
+ georg scholz
1163
+ georg schrimpf
1164
+ george abe
1165
+ george ault
1166
+ george bain
1167
+ george barbier
1168
+ george barker
1169
+ george barret sr
1170
+ george bell
1171
+ george bellows
1172
+ george benjamin luks
1173
+ george biddle
1174
+ george caleb bingham
1175
+ george catlin
1176
+ george cruikshank
1177
+ george fiddes watt
1178
+ george frederic watts
1179
+ george frederick harris
1180
+ george gardner symons
1181
+ george grosz
1182
+ george hendrik breitner
1183
+ george henry
1184
+ george hurrell
1185
+ george inness
1186
+ george jamesone
1187
+ george lucas
1188
+ george luks
1189
+ george morrison
1190
+ george paul chalmers
1191
+ george pirie
1192
+ george reid
1193
+ george romney
1194
+ george stubbs
1195
+ george tooker
1196
+ abstract sculpture
1197
+ georges braque
1198
+ georges de la tour
1199
+ georges lacombe
1200
+ georges lemmen
1201
+ georges rouault
1202
+ georges seurat
1203
+ georges stein
1204
+ georgia okeeffe
1205
+ gerald brom
1206
+ gerald kelly
1207
+ gerard david
1208
+ gerard de lairesse
1209
+ gerard houckgeest
1210
+ gerard seghers
1211
+ gerard sekoto
1212
+ anime drawing
1213
+ gerard ter borch
1214
+ gerard soest
1215
+ gerda wegener
1216
+ gerhard richter
1217
+ gerbrand van den eeckhout
1218
+ germaine krull
1219
+ gerrit adriaenszoon berckheyde
1220
+ gerrit dou
1221
+ gertrude abercrombie
1222
+ art deco sculpture
1223
+ gertrude harvey
1224
+ géza dósa
1225
+ géza udvary
1226
+ engraving
1227
+ giacomo balla
1228
+ gian lorenzo bernini
1229
+ giger
1230
+ gil elvgren
1231
+ gilbert stuart
1232
+ gilles beloeil
1233
+ gillis rombouts
1234
+ gino severini
1235
+ giorgio de chirico
1236
+ giorgio morandi
1237
+ giorgione
1238
+ giotto
1239
+ giovanni antonio galli
1240
+ giovanni battista cipriani
1241
+ giovanni battista gaulli
1242
+ giovanni battista piazzetta
1243
+ giovanni battista piranesi
1244
+ giovanni battista tiepolo
1245
+ giovanni bellini
1246
+ giovanni bernardino azzolini
1247
+ giovanni boldini
1248
+ giovanni fattori
1249
+ giovanni francesco barbieri
1250
+ giovanni giacometti
1251
+ giovanni lanfranco
1252
+ ultrafine detailed painting
1253
+ giovanni paolo pannini
1254
+ giuseppe abbati
1255
+ giuseppe antonio petrini
1256
+ giuseppe arcimboldo
1257
+ giuseppe bernardino bison
1258
+ giuseppe camuncoli
1259
+ giuseppe de nittis
1260
+ giuseppe grisoni
1261
+ giuseppe tominz
1262
+ glen angus
1263
+ glen keane
1264
+ glenn fabry
1265
+ glennray tutor
1266
+ gloria stoll karn
1267
+ godfried schalcken
1268
+ gong xian
1269
+ gordon parks
1270
+ goro fujita
1271
+ gottfried helnwein
1272
+ govert dircksz camphuysen
1273
+ govert flinck
1274
+ goyō hashiguchi
1275
+ grace cossington smith
1276
+ grace english
1277
+ graham forsythe
1278
+ graham sutherland
1279
+ grandma moses
1280
+ grant wood
1281
+ grayson perry
1282
+ greg hildebrandt
1283
+ greg rutkowski
1284
+ greg spalenka
1285
+ greg staples
1286
+ gregory crewdson
1287
+ gregory gillespie
1288
+ gregory manchess
1289
+ grete stern
1290
+ grigoriy myasoyedov
1291
+ grzegorz rutkowski
1292
+ gu an
1293
+ gu hongzhong
1294
+ guan daosheng
1295
+ guido borelli da caluso
1296
+ guido reni
1297
+ guillermo del toro
1298
+ guo xi
1299
+ gustaf tenggren
1300
+ gustav dore
1301
+ gustav doré
1302
+ gustav klimt
1303
+ gustave baumann
1304
+ gustave boulanger
1305
+ gustave caillebotte
1306
+ gustave courbet
1307
+ gustave dore
1308
+ gustave doré
1309
+ gustave moreau
1310
+ gustave van de woestijne
1311
+ guy denning
1312
+ guy rose
1313
+ gwen john
1314
+ gwenny griffiths
1315
+ gwilym prichard
1316
+ gyula aggházy
1317
+ gyula batthyány
1318
+ gyula benczúr
1319
+ gyula derkovits
1320
+ h r giger
1321
+ hp lovecraft
1322
+ haddon sundblom
1323
+ hajime sorayama
1324
+ hal foster
1325
+ hamilton sloan
1326
+ hamish macdonald
1327
+ han gan
1328
+ hannabarbera
1329
+ hannah frank
1330
+ hanns katz
1331
+ hans asper
1332
+ hans baldung
1333
+ hans baluschek
1334
+ hans bellmer
1335
+ hans bol
1336
+ hans burgkmair
1337
+ hans erni
1338
+ hans fischer
1339
+ hans gude
1340
+ hans hofmann
1341
+ hans makart
1342
+ hans memling
1343
+ hans mertens
1344
+ hans von aachen
1345
+ hans von bartels
1346
+ harald giersing
1347
+ harold gilman
1348
+ harold harvey
1349
+ harold sandys williamson
1350
+ harold von schmidt
1351
+ harriet backer
1352
+ harrington mann
1353
+ harrison fisher
1354
+ harry clarke
1355
+ harry morley
1356
+ harumi hironaka
1357
+ harvey dunn
1358
+ harvey kurtzman
1359
+ harvey pratt
1360
+ hasegawa tōhaku
1361
+ hasui kawase
1362
+ hayao miyazaki
1363
+ heather hudson
1364
+ hedda sterne
1365
+ heinrich hofmann
1366
+ heinrich kley
1367
+ heinrich lefler
1368
+ heinrich maria davringhausen
1369
+ heinz anger
1370
+ helen edwards
1371
+ helen frankenthaler
1372
+ helen huang
1373
+ helene schjerfbeck
1374
+ helmut newton
1375
+ hendrick avercamp
1376
+ hendrick bloemaert
1377
+ hendrick terbrugghen
1378
+ hendrick van balen
1379
+ hendrick van streeck
1380
+ hendrik goltzius
1381
+ hendrik martenszoon sorgh
1382
+ hendrik van steenwijk i
1383
+ hendrik van steenwijk ii
1384
+ hendrik willem mesdag
1385
+ henri alphonse barnoin
1386
+ henri biva
1387
+ henri cartierbresson
1388
+ henri harpignies
1389
+ henri le sidaner
1390
+ henri matisse
1391
+ henri rousseau
1392
+ henriette wyeth
1393
+ henrik weber
1394
+ henry bright
1395
+ henry carr
1396
+ henry fuseli
1397
+ henry heerup
1398
+ henry justice ford
1399
+ henry lamb
1400
+ henry moore
1401
+ henry ossawa tanner
1402
+ henry otto wix
1403
+ henry raeburn
1404
+ henry raleigh
1405
+ henry scott tuke
1406
+ henry tonks
1407
+ henry van de velde
1408
+ henry wallis
1409
+ henry woods
1410
+ henryk siemiradzki
1411
+ herb ritts
1412
+ herbert bayer
1413
+ herbert james gunn
1414
+ herman saftleven
1415
+ herman van swanevelt
1416
+ hermenegildo anglada camarasa
1417
+ hieronymous bosch
1418
+ hieronymus bosch
1419
+ hikari shimoda
1420
+ hilma af klint
1421
+ hiromu arakawa
1422
+ hiroshi nagai
1423
+ hiroshi yoshida
1424
+ hiroshige
1425
+ hishikawa moronobu
1426
+ hisui sugiura
1427
+ hokusai
1428
+ holger roed
1429
+ honoré daumier
1430
+ horace vernet
1431
+ horatio mcculloch
1432
+ horatio nelson poole
1433
+ hovsep pushman
1434
+ howard butterworth
1435
+ howard chandler christy
1436
+ howard chaykin
1437
+ howard finster
1438
+ howard lyon
1439
+ howard pyle
1440
+ hr giger
1441
+ hu jieqing
1442
+ hua yan
1443
+ huang binhong
1444
+ huang ding
1445
+ huang gongwang
1446
+ huang guangjian
1447
+ huang ji
1448
+ huang shen
1449
+ huang tingjian
1450
+ hubert robert
1451
+ hubert van eyck
1452
+ hubert von herkomer
1453
+ hugh ferriss
1454
+ hugh william williams
1455
+ hugo anton fisher
1456
+ hugo heyrman
1457
+ hugo scheiber
1458
+ hugo simberg
1459
+ hugo van der goes
1460
+ humberto castro
1461
+ hundertwasser
1462
+ hyacinthe rigaud
1463
+ ian mcque
1464
+ ian miller
1465
+ ian spriggs
1466
+ ida rentoul outhwaite
1467
+ ignacio zuloaga
1468
+ ignacy witkiewicz
1469
+ ignat bednarik
1470
+ igor grabar
1471
+ igor kieryluk
1472
+ igor morski
1473
+ igor zenin
1474
+ ikuo hirayama
1475
+ illarion pryanishnikov
1476
+ ilya glazunov
1477
+ ilya kuvshinov
1478
+ ilya ostroukhov
1479
+ ilya repin
1480
+ ilya yefimovich repin
1481
+ ina wong
1482
+ ino
1483
+ ion andreescu
1484
+ irakli nadar
1485
+ irma stern
1486
+ isaac grünewald
1487
+ isaac levitan
1488
+ isaac soyer
1489
+ isabel codrington
1490
+ isabel naftel
1491
+ isamu noguchi
1492
+ isidor kaufman
1493
+ ismail acar
1494
+ ismail gulgee
1495
+ ismail inceoglu
1496
+ israel tsvaygenbaum
1497
+ istván csók
1498
+ istván orosz
1499
+ istván réti
1500
+ itō jakuchū
1501
+ itō shinsui
1502
+ itshak holtz
1503
+ ivan aivazovsky
1504
+ ivan albright
1505
+ ivan bilibin
1506
+ ivan generalić
1507
+ ivan kramskoi
1508
+ ivan mrkvička
1509
+ ivan shishkin
1510
+ ivan trush
1511
+ ivana kobilca
1512
+ ivor davies
1513
+ ivor williams
1514
+ iwasa matabei
1515
+ j alden weir
1516
+ j c leyendecker
1517
+ j frederick smith
1518
+ j l lund
1519
+ j m w turner
1520
+ j ottis adams
1521
+ jc leyendecker
1522
+ jmw turner
1523
+ jacek malczewski
1524
+ jacek yerka
1525
+ jack boul
1526
+ jack butler yeats
1527
+ jack davis
1528
+ jack kirby
1529
+ jack levine
1530
+ jack roth
1531
+ jack smith
1532
+ jackson pollock
1533
+ jacob adriaensz backer
1534
+ jacob burck
1535
+ jacob collins
1536
+ jacob de heusch
1537
+ jacob gerritsz cuyp
1538
+ jacob jordaens
1539
+ jacob kainen
1540
+ jacob koninck
1541
+ jacob lawrence
1542
+ jacob maris
1543
+ jacob more
1544
+ jacob ochtervelt
1545
+ jacob philipp hackert
1546
+ jacob pynas
1547
+ jacob savery
1548
+ jacob toorenvliet
1549
+ jacob van campen
1550
+ jacob van der ulft
1551
+ jacob van ruisdael
1552
+ jacopo amigoni
1553
+ jacopo bassano
1554
+ jacopo bellini
1555
+ jacopo de barbari
1556
+ jacopo pontormo
1557
+ jacques blanchard
1558
+ jacques callot
1559
+ jacques daret
1560
+ jacques sablet
1561
+ jacques villon
1562
+ jacqueslouis david
1563
+ jaime colson
1564
+ jaime jones
1565
+ jakob gauermann
1566
+ jakub rozalski
1567
+ jakub różalski
1568
+ jakub schikaneder
1569
+ james abbott mcneill whistler
1570
+ james barry
1571
+ james bateman
1572
+ james baynes
1573
+ james bolivar manson
1574
+ james c christensen
1575
+ james cadenhead
1576
+ james campbell noble
1577
+ james christensen
1578
+ james cowie
1579
+ james cromar watt
1580
+ james dickson innes
1581
+ james ensor
1582
+ james giles
1583
+ james gilleard
1584
+ james gillick
1585
+ james gillray
1586
+ james gurney
1587
+ james guthrie
1588
+ james humbert craig
1589
+ james jean
1590
+ james mcbey
1591
+ james mcintosh patrick
1592
+ james mcneill whistler
1593
+ james montgomery flagg
1594
+ james morris
1595
+ james morrison
1596
+ james paick
1597
+ james paterson
1598
+ james peale
1599
+ james pittendrigh macgillivray
1600
+ james rosenquist
1601
+ james ryman
1602
+ james thomas watts
1603
+ james tissot
1604
+ james warhola
1605
+ james wood
1606
+ jamie hewlett
1607
+ jamie wyeth
1608
+ jan antonisz van ravesteyn
1609
+ jan asselijn
1610
+ jan baptist weenix
1611
+ jan brett
1612
+ jan cornelisz vermeyen
1613
+ jan cox
1614
+ jan davidsz de heem
1615
+ jan de baen
1616
+ jan de bray
1617
+ jan gossaert
1618
+ jan griffier
1619
+ jan hackaert
1620
+ jan kip
1621
+ jan lievens
1622
+ jan matejko
1623
+ jan miel
1624
+ jan miense molenaer
1625
+ jan steen
1626
+ jan toorop
1627
+ jan van bijlert
1628
+ jan van de cappelle
1629
+ jan van der heyden
1630
+ jan van eyck
1631
+ jan van goyen
1632
+ jan van huysum
1633
+ jan van mieris
1634
+ jan verkolje
1635
+ jan victors
1636
+ jan wijnants
1637
+ jan wyck
1638
+ jan zrzavý
1639
+ jane carpanini
1640
+ jane frank
1641
+ jane freeman
1642
+ jane freilicher
1643
+ jane hawkins
1644
+ jane kelly
1645
+ jane nasmyth
1646
+ jane small
1647
+ janet archer
1648
+ janet dawson
1649
+ janet fish
1650
+ jános vaszary
1651
+ january suchodolski
1652
+ jarosław jaśnikowski
1653
+ jason benjamin
1654
+ jason chan
1655
+ jason edmiston
1656
+ jason felix
1657
+ jasper francis cropsey
1658
+ jasper johns
1659
+ jean antoine watteau
1660
+ jean arp
1661
+ jean auguste dominique ingres
1662
+ jean baptiste debret
1663
+ jean béraud
1664
+ jean clark
1665
+ jean colombe
1666
+ jean delville
1667
+ jean dubuffet
1668
+ jean dufy
1669
+ jean fouquet
1670
+ jean giraud
1671
+ jean hélion
1672
+ jean hey
1673
+ jean jouvenet
1674
+ jean metzinger
1675
+ jean micheal basquiat
1676
+ jean moebius giraud
1677
+ jean petitot
1678
+ jeanaugustedominique ingres
1679
+ jeanlouisernest meissonier
1680
+ jeanmarc nattier
1681
+ jeanmichel basquiat
1682
+ jeanna bauck
1683
+ jeanne hébuterne
1684
+ jeff easley
1685
+ jeff koons
1686
+ jeff miracola
1687
+ jeffrey catherine jones
1688
+ jeffrey smith
1689
+ jennifer janesko
1690
+ jenny eakin delony
1691
+ jenny saville
1692
+ jenő barcsay
1693
+ jens ferdinand willumsen
1694
+ jens juel
1695
+ jeong seon
1696
+ jeremiah ketner
1697
+ jeremy chong
1698
+ jeremy geddes
1699
+ jerry pinkney
1700
+ jerry schatzberg
1701
+ jerry weiss
1702
+ jerzy kossak
1703
+ jesper ejsing
1704
+ jesper myrfors
1705
+ jesse richards
1706
+ jessica rossier
1707
+ jessie willcox smith
1708
+ jiao bingzhen
1709
+ jim burns
1710
+ jim davis
1711
+ jim dine
1712
+ jim lee
1713
+ jim murray
1714
+ jim nelson
1715
+ jin nong
1716
+ jiro yoshihara
1717
+ joachim patinir
1718
+ joan brown
1719
+ joan miro
1720
+ joan miró
1721
+ joan snyder
1722
+ joanna carrington
1723
+ joaquín clausell
1724
+ joaquín sorolla
1725
+ jodorowsky
1726
+ joe bowler
1727
+ joe de mers
1728
+ joe fenton
1729
+ joe jusko
1730
+ joe machine
1731
+ joe mangrum
1732
+ joe shuster
1733
+ johan christian dahl
1734
+ johan jongkind
1735
+ johann berthelsen
1736
+ johann bodin
1737
+ johann christian brand
1738
+ johann friedrich overbeck
1739
+ johann gottfried steffan
1740
+ johann heinrich bleuler
1741
+ johann heinrich meyer
1742
+ johann jakob biedermann
1743
+ johann ludwig bleuler
1744
+ johann zoffany
1745
+ johannes cornelisz verspronck
1746
+ johannes helgeson
1747
+ johannes itten
1748
+ johannes lingelbach
1749
+ johannes mytens
1750
+ johannes vermeer
1751
+ johannes voss
1752
+ johfra bosschart
1753
+ john alexander
1754
+ john anster fitzgerald
1755
+ john armstrong
1756
+ john atherton
1757
+ john atkinson grimshaw
1758
+ john avon
1759
+ john bauer
1760
+ john bellany
1761
+ john berkey
1762
+ john blair
1763
+ john blanche
1764
+ john brack
1765
+ john brown
1766
+ john brown abercromby
1767
+ john button
1768
+ john byrne
1769
+ john cale
1770
+ john carpenter
1771
+ john clayton
1772
+ john clayton adams
1773
+ john collier
1774
+ john constable
1775
+ john duncan fergusson
1776
+ john e berninger
1777
+ john elwood bundy
1778
+ john everett millais
1779
+ john eyre
1780
+ john f francis
1781
+ john f peto
1782
+ john fabian carlson
1783
+ john frederick herring jr
1784
+ john frederick herring sr
1785
+ john frederick kensett
1786
+ john french sloan
1787
+ john fulton folinsbee
1788
+ john george sowerby
1789
+ john gibson
1790
+ john harris
1791
+ john henderson
1792
+ john henry lorimer
1793
+ john henry twachtman
1794
+ john howe
1795
+ john hutton
1796
+ john j park
1797
+ john james audubon
1798
+ john kay
1799
+ john keane
1800
+ john la gatta
1801
+ john lavery
1802
+ john linnell
1803
+ john lowrie morrison
1804
+ john luke
1805
+ john macdonald aiken
1806
+ john marin
1807
+ john martin
1808
+ john maxwell
1809
+ john mclaughlin
1810
+ john michael wright
1811
+ john murdoch
1812
+ john noble barlow
1813
+ john opie
1814
+ john parker
1815
+ john pettie
1816
+ john philip falter
1817
+ john platt
1818
+ john plumb
1819
+ john quinton pringle
1820
+ john robertson reid
1821
+ john romita jr
1822
+ john salminen
1823
+ john singer sargent
1824
+ john singleton copley
1825
+ john skinner prout
1826
+ john sloan
1827
+ john souch
1828
+ john steell
1829
+ john steuart curry
1830
+ john stuart ingle
1831
+ john trumbull
1832
+ john watson gordon
1833
+ john william godward
1834
+ john william waterhouse
1835
+ john wilson
1836
+ john wollaston
1837
+ john wonnacott
1838
+ jon foster
1839
+ jon whitcomb
1840
+ jonas de ro
1841
+ jonathan solter
1842
+ joos de momper
1843
+ jordan grimmer
1844
+ jorge jacinto
1845
+ jørgen roed
1846
+ josan gonzalez
1847
+ josé clemente orozco
1848
+ josé malhoa
1849
+ josef abel
1850
+ josef albers
1851
+ josef mánes
1852
+ josep rovira soler
1853
+ joseph badger
1854
+ joseph beuys
1855
+ joseph bowler
1856
+ joseph christian leyendecker
1857
+ joseph cornell
1858
+ joseph decamp
1859
+ joseph delaney
1860
+ joseph ducreux
1861
+ joseph dwight strong
1862
+ joseph henderson
1863
+ joseph kleitsch
1864
+ joseph noel paton
1865
+ joseph raphael
1866
+ joseph severn
1867
+ joseph stella
1868
+ joseph von führich
1869
+ joseph werner
1870
+ joseph wright of derby
1871
+ analytical art
1872
+ antipodeans
1873
+ josephine wall
1874
+ josetsu
1875
+ joshua reynolds
1876
+ josse lieferinxe
1877
+ jozef israëls
1878
+ józef mehoffer
1879
+ józef pankiewicz
1880
+ jozef simmler
1881
+ art language
1882
+ józsef borsos
1883
+ ju chao
1884
+ ju lian
1885
+ juan de flandes
1886
+ juan giménez
1887
+ juan gris
1888
+ juan luna
1889
+ juan ogorman
1890
+ judith brown
1891
+ judith leyster
1892
+ judy cassab
1893
+ judy takács
1894
+ jules bastienlepage
1895
+ jules breton
1896
+ jules chéret
1897
+ jules joseph lefebvre
1898
+ jules pascin
1899
+ arte povera
1900
+ jules tavernier
1901
+ julia margaret cameron
1902
+ julian fałat
1903
+ julian onderdonk
1904
+ julian schnabel
1905
+ julie bell
1906
+ ascii art
1907
+ julio gonzález
1908
+ julio larraz
1909
+ julius exner
1910
+ julius leblanc stewart
1911
+ juliusz kossak
1912
+ jung park
1913
+ junji ito
1914
+ justin currie
1915
+ justin gerard
1916
+ justin sweet
1917
+ justus van gent
1918
+ kaburagi kiyokata
1919
+ kadir nelson
1920
+ kahlo
1921
+ kaigetsudō ando
1922
+ kaii higashiyama
1923
+ kalervo palsa
1924
+ kamisaka sekka
1925
+ kandinsky
1926
+ kanō eitoku
1927
+ kanō hōgai
1928
+ bengal school art
1929
+ kanō motonobu
1930
+ berlin secession
1931
+ kanō sansetsu
1932
+ kanō sanraku
1933
+ kanō tanyū
1934
+ black arts movement
1935
+ kanzan shimomura
1936
+ karel dujardin
1937
+ bertalan karlovszky
1938
+ karel van mander
1939
+ karl bodmer
1940
+ karl bryullov
1941
+ karl hagedorn
1942
+ cloisonnism
1943
+ karl hofer
1944
+ karl kopinski
1945
+ bertram brooker
1946
+ karol bak
1947
+ karolis strautniekas
1948
+ károly brocky
1949
+ károly ferenczy
1950
+ károly kernstok
1951
+ károly kisfaludy
1952
+ károly lotz
1953
+ károly patkó
1954
+ kate beaton
1955
+ kate greenaway
1956
+ käthe kollwitz
1957
+ kathleen scott
1958
+ kati horna
1959
+ katia chausheva
1960
+ katsukawa shunei
1961
+ context art
1962
+ katsukawa shunsen
1963
+ katsukawa shunshō
1964
+ katsushika hokusai
1965
+ betye saar
1966
+ katsuya terada
1967
+ kawai gyokudō
1968
+ kawanabe kyōsai
1969
+ kawase hasui
1970
+ kay nielsen
1971
+ kay sage
1972
+ crystal cubism
1973
+ kazimierz alchimowicz
1974
+ kazimir malevich
1975
+ kees bol
1976
+ kees maks
1977
+ kees scherer
1978
+ kees van dongen
1979
+ keisai eisen
1980
+ keith haring
1981
+ keith henderson
1982
+ keith mallett
1983
+ keith parkinson
1984
+ cynical realism
1985
+ kelly mckernan
1986
+ kelly freas
1987
+ ken danby
1988
+ bikash bhattacharjee
1989
+ ken howard
1990
+ ken kelly
1991
+ bill lewis
1992
+ kelly sueda
1993
+ kenneth noland
1994
+ kentaro miura
1995
+ bill sienkiewicz
1996
+ keos masons
1997
+ danube school
1998
+ kerembeyit
1999
+ kev walker
2000
+ khalil gibran
2001
+ kieran yanner
2002
+ kilian eng
2003
+ kim keever
2004
+ kim tschang yeul
2005
+ billie waters
2006
+ kinuko craft
2007
+ ecological art
2008
+ kishi ganku
2009
+ kitagawa utamaro
2010
+ kitao shigemasa
2011
+ klimt
2012
+ kobayashi kiyochika
2013
+ excessivism
2014
+ kōno bairei
2015
+ blanche hoschedé monet
2016
+ konrad grob
2017
+ konrad klapheck
2018
+ konrad witz
2019
+ konstantin korovin
2020
+ konstantin makovsky
2021
+ konstantin savitsky
2022
+ konstantin somov
2023
+ konstantin vasilyev
2024
+ konstantin westchilov
2025
+ konstantin yuon
2026
+ konstantinas ciurlionis
2027
+ koson ohara
2028
+ krenz cushart
2029
+ kristian zahrtmann
2030
+ kristin nelson
2031
+ feminist art
2032
+ bob singer
2033
+ kun can
2034
+ kuroda seiki
2035
+ bob thompson
2036
+ kurt schwitters
2037
+ kurt wenner
2038
+ kusama
2039
+ kyffin williams
2040
+ kyle lambert
2041
+ bogi fabian
2042
+ ladrönn
2043
+ lajos bruck
2044
+ lajos gulácsy
2045
+ bohumil kubista
2046
+ lajos tihanyi
2047
+ fluxus
2048
+ lam qua
2049
+ lambert doomer
2050
+ lambert jacobsz
2051
+ lan ying
2052
+ lari pittman
2053
+ larry elmore
2054
+ larry fink
2055
+ larry rivers
2056
+ funk art
2057
+ bonnard pierre
2058
+ lasar segall
2059
+ boris vallejo
2060
+ lászló mednyánszky
2061
+ lászló paál
2062
+ laura ford
2063
+ laura knight
2064
+ laura muntz lyall
2065
+ generative art
2066
+ laura wheeler waring
2067
+ laurel burch
2068
+ laurie lipton
2069
+ laurits tuxen
2070
+ lawren harris
2071
+ boris vladimirski
2072
+ geometric abstract art
2073
+ lawrence harris
2074
+ leandro erlich
2075
+ german romanticism
2076
+ leconte stewart
2077
+ lee jeffries
2078
+ lee madgwick
2079
+ leland bell
2080
+ leng mei
2081
+ lennie lee
2082
+ brad holland
2083
+ leo leuppi
2084
+ léon bakst
2085
+ leon kapliński
2086
+ leon kossoff
2087
+ leon kroll
2088
+ leon wyczółkowski
2089
+ leona wood
2090
+ leonaert bramer
2091
+ leonard appelbee
2092
+ heidelberg school
2093
+ leonard long
2094
+ leonard ochtman
2095
+ leonardo da vinci
2096
+ leonid afremov
2097
+ leonid pasternak
2098
+ leonor fini
2099
+ leonora carrington
2100
+ leopold gottlieb
2101
+ leroy neiman
2102
+ les edwards
2103
+ lesser ury
2104
+ lev lvovich kamenev
2105
+ lewis henry meakin
2106
+ li cheng
2107
+ li chevalier
2108
+ li di
2109
+ li kan
2110
+ li keran
2111
+ li shan
2112
+ li shixing
2113
+ li song
2114
+ li tang
2115
+ li tiefu
2116
+ li zai
2117
+ liam wong
2118
+ liang kai
2119
+ brian bolland
2120
+ lichtenstein
2121
+ incoherents
2122
+ lilia alvarado
2123
+ lilla cabot perry
2124
+ lillian bassman
2125
+ brian despain
2126
+ limbourg brothers
2127
+ lin liang
2128
+ brian dunlop
2129
+ linda sutton
2130
+ lionel lindsay
2131
+ lionel walden
2132
+ lisa frank
2133
+ lisa milroy
2134
+ international gothic
2135
+ lisa yuskavage
2136
+ lise deharme
2137
+ liu haisu
2138
+ liu jun
2139
+ liza donnelly
2140
+ lizzy ansingh
2141
+ lodewijk bruckman
2142
+ lois dodd
2143
+ lois mailou jones
2144
+ lois van baarle
2145
+ loish
2146
+ brian thomas
2147
+ lorenzo lotto
2148
+ lorraine fox
2149
+ lotte reiniger
2150
+ louis anquetin
2151
+ louis buvelot
2152
+ louis comfort tiffany
2153
+ louis de caullery
2154
+ louis eilshemius
2155
+ louis faurer
2156
+ bridget bate tichenor
2157
+ louis grell
2158
+ louis hersent
2159
+ louis janmot
2160
+ louis le brocquy
2161
+ louis le nain
2162
+ bridget riley
2163
+ louis marcoussis
2164
+ louis stettner
2165
+ louis valtat
2166
+ louis wain
2167
+ louisa matthíasdóttir
2168
+ louisa puller
2169
+ louise abbéma
2170
+ louise bourgeois
2171
+ louise catherine breslau
2172
+ louise nevelson
2173
+ lovecraft
2174
+ lovis corinth
2175
+ lu guang
2176
+ lu zhi
2177
+ lubin baugin
2178
+ luc tuymans
2179
+ luca della robbia
2180
+ lucas cranach the elder
2181
+ lucas graciano
2182
+ lucas van leyden
2183
+ lucas vorsterman
2184
+ lucian freud
2185
+ lucien pissarro
2186
+ lucio fontana
2187
+ bruce mclean
2188
+ lucy madox brown
2189
+ ludolf bakhuizen
2190
+ ludolf leendertsz de jongh
2191
+ ludovico carracci
2192
+ bruce munro
2193
+ ludwig bemelmans
2194
+ ludwig knaus
2195
+ luděk marold
2196
+ bruce nauman
2197
+ luigi kasimir
2198
+ luis enrique camej
2199
+ luis royo
2200
+ luo mu
2201
+ luo ping
2202
+ lydia field emmet
2203
+ lyle tuttle
2204
+ bruce timm
2205
+ lyonel feininger
2206
+ lyubov popova
2207
+ m c escher
2208
+ ma lin
2209
+ ma quan
2210
+ ma shi
2211
+ ma wan
2212
+ ma yuan
2213
+ bryan organ
2214
+ mab graves
2215
+ mabel rollins harris
2216
+ mac conner
2217
+ maciej kuciara
2218
+ mads berg
2219
+ maeda masao
2220
+ bunny yeager
2221
+ magali villeneuve
2222
+ byeon sangbyeok
2223
+ makoto aida
2224
+ makoto shinkai
2225
+ byron galvez
2226
+ maksymilian gierymski
2227
+ malcolm drummond
2228
+ malcolm morley
2229
+ malczewski
2230
+ malevich
2231
+ malvin gray johnson
2232
+ man ray
2233
+ caesar van everdingen
2234
+ mandy jurgens
2235
+ marc bell
2236
+ marc chagall
2237
+ marc simonetti
2238
+ marcel duchamp
2239
+ marcello bacciarelli
2240
+ marcin zaleski
2241
+ marco mazzoni
2242
+ neoromanticism
2243
+ marek okon
2244
+ margaret boden
2245
+ margaret graeme niven
2246
+ margaret keane
2247
+ margaret macdonald mackintosh
2248
+ marguerite zorach
2249
+ marià fortuny
2250
+ maria sibylla merian
2251
+ marianne north
2252
+ marianne von werefkin
2253
+ marie angel
2254
+ marie bashkirtseff
2255
+ marie bracquemond
2256
+ marie krøyer
2257
+ marie laurencin
2258
+ marilyn bendell
2259
+ marina abramović
2260
+ mario sironi
2261
+ marion wachtel
2262
+ mariotto albertinelli
2263
+ marius borgeaud
2264
+ mark arian
2265
+ mark boyle
2266
+ mark brooks
2267
+ mark english
2268
+ mark gertler
2269
+ mark keathley
2270
+ mark poole
2271
+ mark rothko
2272
+ mark ryden
2273
+ mark tedin
2274
+ mark zug
2275
+ marsden hartley
2276
+ marshall arisman
2277
+ martin deschambault
2278
+ martin johnson heade
2279
+ martin kober
2280
+ martin schoeller
2281
+ martin schongauer
2282
+ martine johanna
2283
+ martinus rørbye
2284
+ martiros saryan
2285
+ paris school
2286
+ maruyama ōkyo
2287
+ mary adshead
2288
+ mary agnes yerkes
2289
+ mary beale
2290
+ mary black
2291
+ mary blair
2292
+ mary callery
2293
+ mary cameron
2294
+ mary cassatt
2295
+ plasticien
2296
+ mary davis
2297
+ mary dignam
2298
+ mary elizabeth price
2299
+ mary grant
2300
+ mary hallock foote
2301
+ mary mccrossan
2302
+ mary moser
2303
+ masamune shirow
2304
+ masolino
2305
+ mathias kollros
2306
+ mathieu le nain
2307
+ mati klarwein
2308
+ matsumura goshun
2309
+ matt cavotta
2310
+ preraphaelitism
2311
+ matt stewart
2312
+ matt groening
2313
+ matthew smith
2314
+ matthias jung
2315
+ matthias stom
2316
+ matthijs maris
2317
+ mattias adolfsson
2318
+ private press
2319
+ maurice boitel
2320
+ maurice braun
2321
+ maurice de vlaminck
2322
+ maurice denis
2323
+ maude kaufman eggemeyer
2324
+ maurice prendergast
2325
+ maurice sendak
2326
+ maurice utrillo
2327
+ maurycy gottlieb
2328
+ max beckmann
2329
+ max buri
2330
+ max dupain
2331
+ max ernst
2332
+ max gubler
2333
+ max klinger
2334
+ max liebermann
2335
+ max pechstein
2336
+ max slevogt
2337
+ max švabinský
2338
+ qajar art
2339
+ max weber
2340
+ maxfield parrish
2341
+ maxim verehin
2342
+ maximilien luce
2343
+ maxwell bates
2344
+ maxwell gordon lightfoot
2345
+ may louise greville cooksey
2346
+ mc escher
2347
+ mckadesinsanity
2348
+ rayonism
2349
+ mead schaeffer
2350
+ mei qing
2351
+ meindert hobbema
2352
+ melchior dhondecoeter
2353
+ melchior lorck
2354
+ melissa benson
2355
+ melozzo da forlì
2356
+ menez
2357
+ meredith dillman
2358
+ mi fu
2359
+ miao fu
2360
+ michael ancher
2361
+ michael andrews
2362
+ michael cheval
2363
+ michael dahl
2364
+ michael flohr
2365
+ michael ford
2366
+ michael garmash
2367
+ michael goldberg
2368
+ michael james smith
2369
+ michael komarck
2370
+ michael leunig
2371
+ michael malm
2372
+ michael sittow
2373
+ michael whelan
2374
+ michaelangelo
2375
+ michal karcz
2376
+ michał karcz
2377
+ michalis oikonomou
2378
+ michel delacroix
2379
+ michelangelo
2380
+ michelangelo buonarotti
2381
+ michelangelo buonarroti
2382
+ michelangelo merisi da caravaggio
2383
+ michiel jansz van mierevelt
2384
+ michiel van musscher
2385
+ mihály munkácsy
2386
+ mihály zichy
2387
+ miho hirano
2388
+ mikalojus konstantinas ciurlionis
2389
+ serial art
2390
+ mike bierek
2391
+ mike deodato
2392
+ mike mignola
2393
+ mike winkelmann
2394
+ mikhail evstafiev
2395
+ mikhail larionov
2396
+ shock art
2397
+ mikhail nesterov
2398
+ mikhail vrubel
2399
+ mikhail yuryevich lermontov
2400
+ miklós barabás
2401
+ mikhail lebedev
2402
+ mildred anne butler
2403
+ miles johnston
2404
+ millard sheets
2405
+ milton avery
2406
+ milton caniff
2407
+ milton glaser
2408
+ mirabello cavalori
2409
+ mitchell johnson
2410
+ miyamoto
2411
+ miyazaki
2412
+ moebius
2413
+ mœbius
2414
+ moïse kisling
2415
+ mondrian
2416
+ monet
2417
+ morgan russell
2418
+ mori sosen
2419
+ mort künstler
2420
+ moses soyer
2421
+ mstislav dobuzhinsky
2422
+ mucha
2423
+ muirhead bone
2424
+ synchromism
2425
+ munch
2426
+ muqi
2427
+ murakami
2428
+ muriel brandt
2429
+ murray tinkelman
2430
+ synthetism
2431
+ mykola burachek
2432
+ myles birket foster
2433
+ n c wyeth
2434
+ nc wyeth
2435
+ nadim karam
2436
+ nadir afonso
2437
+ nagasawa rosetsu
2438
+ nan goldin
2439
+ nancy graves
2440
+ naoko takeuchi
2441
+ naomi okubo
2442
+ natalia goncharova
2443
+ nathan oliveira
2444
+ nathan wyburn
2445
+ nathaniel hone
2446
+ national geographic
2447
+ naza
2448
+ neal adams
2449
+ neil blevins
2450
+ neil boyle
2451
+ neil welliver
2452
+ neil williams
2453
+ nell dorr
2454
+ nelson alexander ross
2455
+ nene thomas
2456
+ nevercrew
2457
+ neysa mcmein
2458
+ ni zan
2459
+ niccolò dell abbate
2460
+ nicholas hilliard
2461
+ nicholas roerich
2462
+ nick gentry
2463
+ nicola samori
2464
+ nicolaes maes
2465
+ nicolaes pieterszoon berchem
2466
+ nicolas de staël
2467
+ nicolas lancret
2468
+ nicolas poussin
2469
+ nicolas toussaint charlet
2470
+ nicoletta ceccoli
2471
+ nikita veprikov
2472
+ niklaus manuel
2473
+ niko henrichon
2474
+ nikolai astrup
2475
+ nikolai ge
2476
+ nikolai yaroshenko
2477
+ nikolaj abraham abildgaard
2478
+ nikolay makovsky
2479
+ nikolay nikanorovich dubovskoy
2480
+ nil gleyen
2481
+ nils von dardel
2482
+ nina hamnett
2483
+ nishikawa sukenobu
2484
+ noah bradley
2485
+ noel counihan
2486
+ noémi ferenczy
2487
+ norah neilson gray
2488
+ noriyoshi ohrai
2489
+ norma bull
2490
+ norman garstin
2491
+ norman hepple
2492
+ norman lewis
2493
+ norman rockwell
2494
+ norman saunders
2495
+ nuno gonçalves
2496
+ okeeffe
2497
+ odd nerdrum
2498
+ odilon redon
2499
+ ogata gekkō
2500
+ ogata kōrin
2501
+ ohara koson
2502
+ okumura masanobu
2503
+ oleg lipchenko
2504
+ oleg oprisco
2505
+ olga boznańska
2506
+ olha darchuk
2507
+ oliver sin
2508
+ olivia de berardinis
2509
+ olivia peguero
2510
+ orazio gentileschi
2511
+ osamu tezuka
2512
+ oskar kokoschka
2513
+ oskar schlemmer
2514
+ osman hamdi bey
2515
+ ossip zadkine
2516
+ oswald achenbach
2517
+ oswald birley
2518
+ oswaldo guayasamín
2519
+ otakar kubín
2520
+ ottó baditz
2521
+ otto dix
2522
+ otto eckmann
2523
+ otto piene
2524
+ otto pilny
2525
+ otto stark
2526
+ pablo carpio
2527
+ pablo munoz gomez
2528
+ pablo picasso
2529
+ pacita abad
2530
+ pál szinyei merse
2531
+ pamphilus
2532
+ pan yuliang
2533
+ paolo uccello
2534
+ paolo veronese
2535
+ parmigianino
2536
+ pascale campion
2537
+ pat adams
2538
+ patrick adam
2539
+ patrick brown
2540
+ patrick ching
2541
+ patrick dougherty
2542
+ patrick hall
2543
+ patrick henry bruce
2544
+ patrick heron
2545
+ patrick nagel
2546
+ patrick nasmyth
2547
+ patrick pietropoli
2548
+ patrick woodroffe
2549
+ paul bird
2550
+ paul bril
2551
+ paul cadmus
2552
+ paul cezanne
2553
+ paul cézanne
2554
+ paul cornoyer
2555
+ paul davis
2556
+ paul delvaux
2557
+ paul émile chabas
2558
+ paul gauguin
2559
+ paul georges
2560
+ paul guigou
2561
+ paul gustav fischer
2562
+ paul gustave fischer
2563
+ paul harvey
2564
+ paul henry
2565
+ paul jacob naftel
2566
+ paul kane
2567
+ paul kelpe
2568
+ paul klee
2569
+ paul lehr
2570
+ paul lohse
2571
+ paul nash
2572
+ paul ranson
2573
+ paul signac
2574
+ paula rego
2575
+ paulus moreelse
2576
+ paulus potter
2577
+ pavel fedotov
2578
+ pavel filonov
2579
+ pearl frush
2580
+ peder severin krøyer
2581
+ pedro álvarez castelló
2582
+ pedro figari
2583
+ peggy angus
2584
+ peggy bacon
2585
+ penleigh boyd
2586
+ penry williams
2587
+ per kirkeby
2588
+ perle fine
2589
+ peter alexander hay
2590
+ peter basch
2591
+ peter birmann
2592
+ peter blume
2593
+ peter brook
2594
+ peter churcher
2595
+ peter de seve
2596
+ peter doig
2597
+ peter elson
2598
+ peter fiore
2599
+ peter gric
2600
+ peter helck
2601
+ peter lanyon
2602
+ peter lely
2603
+ peter lindbergh
2604
+ peter madsen
2605
+ peter max
2606
+ peter michael
2607
+ peter mohrbacher
2608
+ peter paul rubens
2609
+ peter prendergast
2610
+ peter scott
2611
+ peter snow
2612
+ peter wells
2613
+ peter wtewael
2614
+ peter zumthor
2615
+ petrus christus
2616
+ petrus van der velden
2617
+ phil koch
2618
+ philip de lászló
2619
+ philip evergood
2620
+ philip guston
2621
+ philip wilson steer
2622
+ philipp veit
2623
+ philippe druillet
2624
+ philips wouwerman
2625
+ phillip otto runge
2626
+ picasso
2627
+ piero della francesca
2628
+ piero di cosimo
2629
+ pierre adolphe valette
2630
+ pierre auguste cot
2631
+ pierre bonnard
2632
+ pierre brissaud
2633
+ pierre mion
2634
+ pierre pellegrini
2635
+ pierre puvis de chavannes
2636
+ pierre roy
2637
+ pierre soulages
2638
+ pierreauguste renoir
2639
+ piet mondrian
2640
+ pieter aertsen
2641
+ pieter bruegel
2642
+ pieter brueghel the younger
2643
+ pieter claesz
2644
+ pieter codde
2645
+ pieter cornelisz van slingelandt
2646
+ pieter de grebber
2647
+ pieter de hooch
2648
+ pieter de ring
2649
+ pieter huys
2650
+ pieter janssens elinga
2651
+ pieter jansz saenredam
2652
+ pieter lastman
2653
+ pieter mulier ii
2654
+ pieter van anraedt
2655
+ pieter van der werff
2656
+ pieter van laer
2657
+ pietro da cortona
2658
+ pietro longhi
2659
+ pietro lorenzetti
2660
+ pietro perugino
2661
+ pinchus kremegne
2662
+ pinturicchio
2663
+ piranesi
2664
+ pisanello
2665
+ pixar
2666
+ pollock
2667
+ pompeo batoni
2668
+ prince hoare
2669
+ prudence heward
2670
+ pruett carter
2671
+ pu hua
2672
+ puru
2673
+ qi baishi
2674
+ qian du
2675
+ qian gu
2676
+ qian xuan
2677
+ qiu ying
2678
+ quentin blake
2679
+ quentin matsys
2680
+ quint buchholz
2681
+ r b kitaj
2682
+ r r mcian
2683
+ rachel reckitt
2684
+ rachel ruysch
2685
+ rachel whiteread
2686
+ rackstraw downes
2687
+ radi nedelchev
2688
+ rafail levitsky
2689
+ rafal olbinski
2690
+ raja ravi varma
2691
+ ralph albert blakelock
2692
+ ralph burke tyree
2693
+ ralph earl
2694
+ ralph horsley
2695
+ ralph mcquarrie
2696
+ randolph caldecott
2697
+ randolph schwabe
2698
+ randy gallegos
2699
+ randy post
2700
+ randy vargas
2701
+ raoul dufy
2702
+ raphael
2703
+ raphaël collin
2704
+ raphael kirchner
2705
+ raphael lacoste
2706
+ raphael soyer
2707
+ raphaelle peale
2708
+ ravi zupa
2709
+ ray caesar
2710
+ ray crooke
2711
+ ray parker
2712
+ raymond briggs
2713
+ raymond coxon
2714
+ raymond han
2715
+ raymond leech
2716
+ raymond saunders
2717
+ raymond swanland
2718
+ raymond teague cowern
2719
+ rebecca guay
2720
+ relja penezic
2721
+ rembrandt
2722
+ rembrandt peale
2723
+ rembrandt van rijn
2724
+ remedios varo
2725
+ ren hang
2726
+ ren xiong
2727
+ ren xun
2728
+ rené auberjonois
2729
+ rené burri
2730
+ rene magritte
2731
+ rené magritte
2732
+ renoir
2733
+ reynolds beal
2734
+ rhads
2735
+ ric nagualero
2736
+ ricardo bofill
2737
+ richard anuszkiewicz
2738
+ richard avedon
2739
+ richard benning
2740
+ richard carline
2741
+ richard corben
2742
+ richard dadd
2743
+ richard demarco
2744
+ richard diebenkorn
2745
+ richard doyle
2746
+ richard estes
2747
+ richard gerstl
2748
+ richard hamilton
2749
+ richard hess
2750
+ richard mayhew
2751
+ richard parkes bonington
2752
+ richard pionk
2753
+ richard schmid
2754
+ richard wilson
2755
+ richard wright
2756
+ richmond barthé
2757
+ richter
2758
+ rick amor
2759
+ rick griffin
2760
+ ridley scott
2761
+ ridolfo ghirlandaio
2762
+ rihard jakopič
2763
+ rinaldo cuneo
2764
+ rita angus
2765
+ riusuke fukahori
2766
+ riza abbasi
2767
+ rob alexander
2768
+ rob gonsalves
2769
+ rob liefeld
2770
+ robert adamson
2771
+ robert antoine pinchon
2772
+ robert ballagh
2773
+ robert bateman
2774
+ robert bechtle
2775
+ róbert berény
2776
+ robert bevan
2777
+ robert brackman
2778
+ robert brough
2779
+ robert bryden
2780
+ robert campin
2781
+ robert colquhoun
2782
+ robert crumb
2783
+ robert delaunay
2784
+ robert dickerson
2785
+ robert falk
2786
+ robert fawcett
2787
+ robert freebairn
2788
+ robert gavin
2789
+ robert griffier
2790
+ robert henderson blyth
2791
+ robert henri
2792
+ robert jacobsen
2793
+ robert koehler
2794
+ robert lenkiewicz
2795
+ robert macbryde
2796
+ robert maguire
2797
+ robert mapplethorpe
2798
+ robert mccall
2799
+ robert mcginnis
2800
+ robert motherwell
2801
+ robert noble
2802
+ robert peak
2803
+ robert rauschenberg
2804
+ robert reid
2805
+ robert scott lauder
2806
+ robert sivell
2807
+ robert thomas
2808
+ robert walker macbeth
2809
+ robert weaver
2810
+ robert weir allan
2811
+ robert william vonnoh
2812
+ robert zünd
2813
+ roberto ferri
2814
+ roberto parada
2815
+ rockwell kent
2816
+ rodel gonzalez
2817
+ rodney matthews
2818
+ rodolfo amoedo
2819
+ rodolfo escalera
2820
+ rodolfo morales
2821
+ rodolphe wytsman
2822
+ roelant savery
2823
+ roger ballen
2824
+ roger deakins
2825
+ roger dean
2826
+ roger wilson dennis
2827
+ rogier van der weyden
2828
+ rolf armstrong
2829
+ romaine brooks
2830
+ romare bearden
2831
+ ron english
2832
+ ron spears
2833
+ ron spencer
2834
+ ron walotsky
2835
+ ronald davis
2836
+ rory mcewen
2837
+ rosa bonheur
2838
+ rosalie emslie
2839
+ rose maynard barton
2840
+ rosemary allan
2841
+ ross tran
2842
+ rossdraws
2843
+ rowena meeks abdy
2844
+ roy de maistre
2845
+ roy decarava
2846
+ roy lichtenstein
2847
+ roy petley
2848
+ roz chast
2849
+ ruan jia
2850
+ rube goldberg
2851
+ rubens peale
2852
+ rudolf ernst
2853
+ rudolf hausner
2854
+ rudolf koller
2855
+ rudolf schlichter
2856
+ rudolf von alt
2857
+ rudolph belarski
2858
+ rudy siswanto
2859
+ rufino tamayo
2860
+ rupert bunny
2861
+ russell chatham
2862
+ russell dongjun lu
2863
+ russell drysdale
2864
+ russell patterson
2865
+ ruth hollingsworth
2866
+ ruth orkin
2867
+ ruth sanderson
2868
+ ruth simpson
2869
+ ryan barger
2870
+ ryan pancoast
2871
+ ryan yee
2872
+ ryohei hase
2873
+ ryōhei koiso
2874
+ ryoji ikeda
2875
+ sadao watanabe
2876
+ sailor moon
2877
+ saitō kiyoshi
2878
+ sakai hōitsu
2879
+ salomon de bray
2880
+ salomon koninck
2881
+ salomon van ruysdael
2882
+ salvador dali
2883
+ salvador dalí
2884
+ sam black
2885
+ sam bosma
2886
+ sam charles
2887
+ sam spratt
2888
+ samuel colman
2889
+ samuel dirksz van hoogstraten
2890
+ samuel f b morse
2891
+ samuel peploe
2892
+ samuel prout
2893
+ samuel scott
2894
+ samuel shelley
2895
+ samuel silva
2896
+ sándor bortnyik
2897
+ sandra chevrier
2898
+ sandro botticelli
2899
+ sanford robinson gifford
2900
+ santiago caruso
2901
+ santiago rusiñol
2902
+ sarah lucas
2903
+ sarah morris
2904
+ satoshi kon
2905
+ saul steinberg
2906
+ saul tepper
2907
+ scarlett hooft graafland
2908
+ scott gustafson
2909
+ scott listfield
2910
+ scott naismith
2911
+ sean scully
2912
+ seb mckinnon
2913
+ sebastian vrancx
2914
+ sebastiano ricci
2915
+ sengai
2916
+ senior artist
2917
+ senior environment artist
2918
+ serge sudeikin
2919
+ sergio larraín
2920
+ serhii vasylkivsky
2921
+ sesshū tōyō
2922
+ seuss dr
2923
+ shaddy safadi
2924
+ shang xi
2925
+ shao mi
2926
+ shen quan
2927
+ shen zhou
2928
+ sheng mao
2929
+ sheng maoye
2930
+ shibata zeshin
2931
+ shigeru aoki
2932
+ shin saimdang
2933
+ shin yunbok
2934
+ shinji aramaki
2935
+ shitao
2936
+ shukei sesson
2937
+ sidney nolan
2938
+ sidney richard percy
2939
+ siegfried haas
2940
+ sigurd swane
2941
+ silvestro lega
2942
+ silvia dimitrova
2943
+ silvia pelissero
2944
+ simon bisley
2945
+ simon marmion
2946
+ simon stalenhag
2947
+ simon stålenhag
2948
+ simon vouet
2949
+ simone martini
2950
+ sin wi
2951
+ sir alfred munnings
2952
+ sir jacob epstein
2953
+ sir john tenniel
2954
+ sir william orpen
2955
+ sir william russell flint
2956
+ slawomir maniak
2957
+ sofonisba anguissola
2958
+ sohrab sepehri
2959
+ soma orlai petrich
2960
+ song xu
2961
+ sonia delaunay
2962
+ sophie anderson
2963
+ sophie gengembre anderson
2964
+ sophie pemberton
2965
+ sōtarō yasui
2966
+ sparth
2967
+ spencer gore
2968
+ stan galli
2969
+ stan stokes
2970
+ stanhope forbes
2971
+ stanislas lépine
2972
+ stanislav zhukovsky
2973
+ stanisław ignacy witkiewicz
2974
+ stanisław masłowski
2975
+ stanisław wyspiański
2976
+ stanley artgerm
2977
+ stanley spencer
2978
+ stefan lochner
2979
+ stephan martiniere
2980
+ stephan martinière
2981
+ stephen bone
2982
+ stephen greene
2983
+ stephen little
2984
+ stephen pace
2985
+ stevan dohanos
2986
+ steve argyle
2987
+ steve dillon
2988
+ steve hanks
2989
+ steve mccurry
2990
+ steven belledin
2991
+ stokely webster
2992
+ storm thorgerson
2993
+ stuart davis
2994
+ studio ghibli
2995
+ sudip roy
2996
+ sugimura jihei
2997
+ sun long
2998
+ sung choi
2999
+ sunil das
3000
+ susan crile
3001
+ suzanne valadon
3002
+ suzuki harunobu
3003
+ svetlin velinov
3004
+ svetoslav roerich
3005
+ syd barrett
3006
+ syd mead
3007
+ sydney carline
3008
+ sydney prior hall
3009
+ sylvain sarrailh
3010
+ sylvester shchedrin
3011
+ sylvia molloy
3012
+ sylvia sleigh
3013
+ szymon czechowicz
3014
+ t c steele
3015
+ tadao ando
3016
+ taddeo gaddi
3017
+ tadeusz makowski
3018
+ taiyō matsumoto
3019
+ takahashi yuichi
3020
+ takashi murakami
3021
+ takato yamamoto
3022
+ takehisa yumeji
3023
+ takeshi obata
3024
+ takeuchi seihō
3025
+ tamara de lempicka
3026
+ tamara lempicka
3027
+ tang di
3028
+ tang yifen
3029
+ tang yin
3030
+ tani bunchō
3031
+ taro okamoto
3032
+ taro yamamoto
3033
+ tatiana hordiienko
3034
+ tatsuyuki tanaka
3035
+ tawaraya sōtatsu
3036
+ ted degrazia
3037
+ ted nasmith
3038
+ telemaco signorini
3039
+ terese nielsen
3040
+ terry morris
3041
+ terry oakes
3042
+ terry redlin
3043
+ the brothers hildebrandt
3044
+ thechamba
3045
+ theo van doesburg
3046
+ theodor philipsen
3047
+ théodore chassériau
3048
+ theodore earl butler
3049
+ théodore géricault
3050
+ theodore major
3051
+ theodore robinson
3052
+ théodore rousseau
3053
+ théodule ribot
3054
+ thierry bisch
3055
+ thomas baines
3056
+ thomas barker
3057
+ thomas blackshear
3058
+ thomas bock
3059
+ thomas campbell
3060
+ thomas cantrell dugdale
3061
+ thomas carr
3062
+ thomas cole
3063
+ thomas couture
3064
+ thomas crane
3065
+ thomas dalziel
3066
+ thomas de keyser
3067
+ thomas dewing
3068
+ thomas doughty
3069
+ thomas eakins
3070
+ thomas fogarty
3071
+ thomas gainsborough
3072
+ thomas hart benton
3073
+ thomas hill
3074
+ thomas kinkade
3075
+ thomas kluge
3076
+ thomas lawrence
3077
+ thomas mann baynes
3078
+ thomas millie dow
3079
+ thomas moran
3080
+ thomas nast
3081
+ thomas rowlandson
3082
+ thomas scholes
3083
+ thomas stothard
3084
+ thomas struth
3085
+ thomas wijck
3086
+ thornton oakley
3087
+ tim biskup
3088
+ tim doyle
3089
+ tim hildebrandt
3090
+ tim okamura
3091
+ tim white
3092
+ tina blondell
3093
+ tina modotti
3094
+ tintoretto
3095
+ titian
3096
+ titus lunter
3097
+ todd lockwood
3098
+ tom bagshaw
3099
+ tom bonson
3100
+ tom chambers
3101
+ tom lovell
3102
+ tom phillips
3103
+ tom roberts
3104
+ tom scott rsa
3105
+ tom thomson
3106
+ tom wesselmann
3107
+ tom whalen
3108
+ tomasz alen kopera
3109
+ tomasz jedruszek
3110
+ tomek setowski
3111
+ tomer hanuka
3112
+ tomi ungerer
3113
+ tomioka tessai
3114
+ tommaso masaccio
3115
+ tomokazu matsuyama
3116
+ tony diterlizzi
3117
+ tony sart
3118
+ tooth wu
3119
+ torii kiyomasu
3120
+ torii kiyomitsu
3121
+ torii kiyonaga
3122
+ torii kiyonobu i
3123
+ tosa mitsunobu
3124
+ tosa mitsuoki
3125
+ tōshi yoshida
3126
+ toshiko okanoue
3127
+ tove jansson
3128
+ toyen
3129
+ toyohara chikanobu
3130
+ toyohara kunichika
3131
+ tracey emin
3132
+ tracy harris
3133
+ tran nguyen
3134
+ trevor brown
3135
+ tsuchida bakusen
3136
+ tsuchiya koitsu
3137
+ tsuguharu foujita
3138
+ tsukioka yoshitoshi
3139
+ tuomas korpi
3140
+ tyler edlin
3141
+ tyler jacobson
3142
+ uemura shōen
3143
+ ulrika pasch
3144
+ umberto boccioni
3145
+ unichi hiratsuka
3146
+ urakusai nagahide
3147
+ utagawa hirokage
3148
+ utagawa hiroshige ii
3149
+ utagawa kunimasa
3150
+ utagawa kunisada
3151
+ utagawa kunisada ii
3152
+ utagawa kuniyoshi
3153
+ utagawa toyoharu
3154
+ utagawa toyokuni
3155
+ utagawa yoshiiku
3156
+ utagawa yoshitaki
3157
+ utagawa yoshitora
3158
+ utagawa yoshitsuya
3159
+ václav brožík
3160
+ valentin aleksandrovich serov
3161
+ valentine hugo
3162
+ valerie petts
3163
+ van gogh
3164
+ vanessa beecroft
3165
+ vanessa bell
3166
+ vasily andreevich tropinin
3167
+ vasily perov
3168
+ vasily polenov
3169
+ vasily surikov
3170
+ vasily vereshchagin
3171
+ vassily maximov
3172
+ vermeer
3173
+ vicente juan masip
3174
+ victo ngai
3175
+ victor adame minguez
3176
+ victor brauner
3177
+ victor enrich
3178
+ victor meirelles
3179
+ victor mosquera
3180
+ victor nizovtsev
3181
+ victor vasarely
3182
+ victor wang
3183
+ victoria francés
3184
+ viktor madarász
3185
+ viktor oliva
3186
+ viktor vasnetsov
3187
+ vilhelm kyhn
3188
+ vincent di fate
3189
+ vincent evans
3190
+ vincent lefevre
3191
+ vincent proce
3192
+ vincent van gogh
3193
+ vincenzo cabianca
3194
+ vincenzo irolli
3195
+ viola paterson
3196
+ violet oakley
3197
+ virgil finlay
3198
+ virginia lee burton
3199
+ vito dancona
3200
+ vittore carpaccio
3201
+ vivian maier
3202
+ vladimir borovikovsky
3203
+ vladimir kush
3204
+ vladimir makovsky
3205
+ vladimir tatlin
3206
+ vladimir tretchikoff
3207
+ vlaho bukovac
3208
+ volkan baga
3209
+ wadim kashin
3210
+ waldo peirce
3211
+ walenty wańkowicz
3212
+ wally wood
3213
+ walt disney
3214
+ walt reed
3215
+ walter bayes
3216
+ walter beach humphrey
3217
+ walter crane
3218
+ walter emerson baum
3219
+ walter haskell hinton
3220
+ walter humphrey
3221
+ walter leighton clark
3222
+ walter osborne
3223
+ walter sickert
3224
+ walter stuempfig
3225
+ wang duo
3226
+ wang e
3227
+ wang fu
3228
+ wang hui
3229
+ wang jian
3230
+ wang lü
3231
+ wang meng
3232
+ wang mian
3233
+ wang shimin
3234
+ wang shishen
3235
+ wang wei
3236
+ wang wu
3237
+ wang ximeng
3238
+ wang yi
3239
+ wang yuan
3240
+ wang yuanqi
3241
+ wang zhenpeng
3242
+ warhol
3243
+ warren mahy
3244
+ warwick goble
3245
+ washington allston
3246
+ wassily kandinsky
3247
+ wayne barlowe
3248
+ wayne england
3249
+ wayne reynolds
3250
+ wayne thiebaud
3251
+ weiwei
3252
+ wen boren
3253
+ wen jia
3254
+ wen tong
3255
+ wen zhengming
3256
+ wendell minor
3257
+ wendy froud
3258
+ wes anderson
3259
+ wes wilson
3260
+ wesley burt
3261
+ wifredo lam
3262
+ wilhelm bendz
3263
+ wilhelm leibl
3264
+ wilhelm marstrand
3265
+ wilhelm schnarrenberger
3266
+ wilhelm trübner
3267
+ will barnet
3268
+ will eisner
3269
+ will ellis
3270
+ willard metcalf
3271
+ willem claeszoon heda
3272
+ willem cornelisz duyster
3273
+ willem de kooning
3274
+ willem drost
3275
+ willem kalf
3276
+ willem maris
3277
+ willem van aelst
3278
+ willem van der vliet
3279
+ willem van haecht
3280
+ willem van mieris
3281
+ william berra
3282
+ william blake
3283
+ william blake richmond
3284
+ william bliss baker
3285
+ william bonnar
3286
+ william brodie
3287
+ william coldstream
3288
+ william conor
3289
+ william crosbie
3290
+ william crozier
3291
+ william dargie
3292
+ william dobell
3293
+ william dobson
3294
+ william dring
3295
+ william edouard scott
3296
+ william edward west
3297
+ william etty
3298
+ william fettes douglas
3299
+ william forsyth
3300
+ william gear
3301
+ william george gillies
3302
+ william glackens
3303
+ william gropper
3304
+ william harnett
3305
+ william hoare
3306
+ william hogarth
3307
+ william holman hunt
3308
+ william holmes sullivan
3309
+ william home lizars
3310
+ william jacob baer
3311
+ william jennys
3312
+ william john thomson
3313
+ william kentridge
3314
+ william langson lathrop
3315
+ william mactaggart
3316
+ william mcgregor paxton
3317
+ william mctaggart
3318
+ william merritt chase
3319
+ william michael harnett
3320
+ william miller
3321
+ william morris
3322
+ william nicholson
3323
+ william powhida
3324
+ william quiller orchardson
3325
+ william simpson
3326
+ william steig
3327
+ william stott
3328
+ william stout
3329
+ william trost richards
3330
+ william turner
3331
+ william woodward
3332
+ william york macgregor
3333
+ william zorach
3334
+ williamadolphe bouguereau
3335
+ willian murai
3336
+ willie ito
3337
+ willy finch
3338
+ wilson irvine
3339
+ winona nelson
3340
+ winslow homer
3341
+ winsor mccay
3342
+ winston churchill
3343
+ władysław czachórski
3344
+ władysław podkowiński
3345
+ wlop
3346
+ wojciech gerson
3347
+ wojciech korneli stattler
3348
+ wojciech kossak
3349
+ wojciech weiss
3350
+ wolf huber
3351
+ wolf kahn
3352
+ wolfgang letti
3353
+ wolfgang lettl
3354
+ wouter pietersz crabeth
3355
+ wu bin
3356
+ wu changshuo
3357
+ wu guanzhong
3358
+ wu hong
3359
+ wu li
3360
+ wu shixian
3361
+ wu wei
3362
+ wu zhen
3363
+ wu zuoren
3364
+ wylie beckert
3365
+ wyndham lewis
3366
+ xanthus russell smith
3367
+ xi gang
3368
+ xia chang
3369
+ xia gui
3370
+ xia yong
3371
+ xiang shengmo
3372
+ xiao yuncong
3373
+ xie he
3374
+ xie huan
3375
+ xie sun
3376
+ xu beihong
3377
+ xu wei
3378
+ xu xi
3379
+ xuande emperor
3380
+ xul solar
3381
+ yan hui
3382
+ yan liben
3383
+ yanagawa shigenobu
3384
+ yang j
3385
+ yang jin
3386
+ yanjun cheng
3387
+ yasar vurdem
3388
+ yasuo kuniyoshi
3389
+ yasutomo oka
3390
+ yayoi kusama
3391
+ yayou kusama
3392
+ yerkaland
3393
+ yi jaegwan
3394
+ yoann lossel
3395
+ yoji shinkawa
3396
+ yokoyama taikan
3397
+ yosa buson
3398
+ yoshihiko wada
3399
+ yoshio markino
3400
+ yoshitaka amano
3401
+ yoshitoshi mori
3402
+ yousuf karsh
3403
+ yu zhiding
3404
+ yuan jiang
3405
+ yuan yao
3406
+ yue minjun
3407
+ yuko shimizu
3408
+ yun duseo
3409
+ yun shouping
3410
+ yuri ivanovich pimenov
3411
+ yuumei
3412
+ yves klein
3413
+ yves tanguy
3414
+ yvonne jacquette
3415
+ zack snyder
3416
+ zack stella
3417
+ zaha hadid
3418
+ zdzislaw beksinski
3419
+ zdzisław beksiński
3420
+ zeen chin
3421
+ zeng jing
3422
+ zhang han
3423
+ zhang kechun
3424
+ zhang lu
3425
+ zhang shuqi
3426
+ zhang wo
3427
+ zhang xiaogang
3428
+ zhang xuan
3429
+ zhang yan
3430
+ zhang yin
3431
+ zhang zeduan
3432
+ zhang zongcang
3433
+ zhao mengfu
3434
+ zhao yong
3435
+ zhao zuo
3436
+ zheng xie
3437
+ zhichao cai
3438
+ zhou chen
3439
+ zhou fang
3440
+ zhou jichang
3441
+ zhou wenjing
3442
+ zhu da
3443
+ zhu derun
3444
+ zinaida serebriakova
3445
+ zoë mozert
3446
+ zou yigui
3447
+ zou zhe
3448
+ zsolt bodoni
3449
+ zygmunt waliszewski
3450
+ dustin nguyen
3451
+ e simms campbell
3452
+ e william gollings
3453
+ ed emshwiller
3454
+ ed paschke
3455
+ edi rama
3456
+ edmund f ward
3457
+ édouard detaille
3458
+ édouard vuillard
3459
+ eduardo paolozzi
3460
+ edward bailey
3461
+ edward burnejones
3462
+ edward george handel lucas
3463
+ edward hicks
3464
+ edward okuń
3465
+ edward ruscha
3466
+ edward wadsworth
3467
+ edwin dickinson
3468
+ edwin g lucas
3469
+ eglon van der neer
3470
+ eiichiro oda
3471
+ einar hakonarson
3472
+ elbridge ayer burbank
3473
+ ellen gallagher
3474
+ elsa bleda
3475
+ emil orlik
3476
+ emilio grau sala
3477
+ emily mason
3478
+ emma geary
3479
+ ken elias
3480
+ brice marden
CSD/main_sim.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Facebook, Inc. and its affiliates.
4
+ # All rights reserved.
5
+
6
+ # This source code is licensed under the license found in the
7
+ # LICENSE file in the root directory of this source tree.
8
+
9
+ import argparse
10
+ import builtins
11
+ import os
12
+ import pathlib
13
+ import random
14
+ import sys
15
+ import warnings
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.parallel
21
+ import torch.backends.cudnn as cudnn
22
+ import torch.distributed as dist
23
+ import torch.optim
24
+ import torch.multiprocessing as mp
25
+ import torch.utils.data
26
+ import torch.utils.data.distributed
27
+ from torchvision import transforms
28
+ import torchvision.models as torchvision_models
29
+ from torchvision.models import VGG16_Weights
30
+
31
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
32
+
33
+ import utils
34
+ from utils import extract_features_pca
35
+ from models import dino_vits, moco_vits
36
+ from data.wikiart import WikiArtD
37
+
38
+
39
+ parser = argparse.ArgumentParser('dynamicDistances-Embedding Generation Module')
40
+ parser.add_argument('--dataset', type=str, required=True, help="Name of the dataset",
41
+ choices=['wikiart'])
42
+
43
+ parser.add_argument('--qsplit', default='query', choices=['query', 'database'], type=str, help="The inferences")
44
+ parser.add_argument('--data-dir', type=str, default=None,
45
+ help='The directory of concerned dataset')
46
+ parser.add_argument('--pt_style', default='csd', type=str)
47
+ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50')
48
+
49
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
50
+ help='number of data loading workers (default: 32)')
51
+ parser.add_argument('-b', '--batch-size', default=64, type=int,
52
+ metavar='N',
53
+ help='mini-batch size (default: 128), this is the total '
54
+ 'batch size of all GPUs on all nodes when '
55
+ 'using Data Parallel or Distributed Data Parallel')
56
+ parser.add_argument('--world-size', default=-1, type=int,
57
+ help='number of nodes for distributed training')
58
+ parser.add_argument('--rank', default=-1, type=int,
59
+ help='node rank for distributed training')
60
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
61
+ help='url used to set up distributed training')
62
+ parser.add_argument('--dist-backend', default='nccl', type=str,
63
+ help='distributed backend')
64
+ parser.add_argument('--seed', default=None, type=int,
65
+ help='seed for initializing training. ')
66
+ parser.add_argument('--gpu', default=None, type=int,
67
+ help='GPU id to use.')
68
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
69
+ help='Use multi-processing distributed training to launch '
70
+ 'N processes per node, which has N GPUs. This is the '
71
+ 'fastest way to use PyTorch for either single node or '
72
+ 'multi node data parallel training')
73
+
74
+ parser.add_argument('--multiscale', default=False, type=utils.bool_flag)
75
+
76
+ # additional configs:
77
+ parser.add_argument('--pretrained', default='', type=str,
78
+ help='path to moco pretrained checkpoint')
79
+ parser.add_argument('--num_loss_chunks', default=1, type=int)
80
+ parser.add_argument('--isvit', action='store_true')
81
+ parser.add_argument('--layer', default=1, type=int, help="layer from end to create descriptors from.")
82
+ parser.add_argument('--feattype', default='normal', type=str, choices=['otprojected', 'weighted', 'concated', 'gram', 'normal'])
83
+ parser.add_argument('--projdim', default=256, type=int)
84
+
85
+ parser.add_argument('-mp', '--model_path', type=str, default=None)
86
+ parser.add_argument('--gram_dims', default=1024, type=int)
87
+ parser.add_argument('--query_count', default=-1, type=int, help='Number of queries to consider for final evaluation. Works only for domainnet')
88
+
89
+ parser.add_argument('--embed_dir', default='./embeddings', type=str, help='Directory to save embeddings')
90
+
91
+ ## Additional config for CSD
92
+ parser.add_argument('--eval_embed', default='head', choices=['head', 'backbone'], help="Which embed to use for eval")
93
+ parser.add_argument('--skip_val', action='store_true')
94
+
95
+
96
+ best_acc1 = 0
97
+
98
+
99
+ def main():
100
+ args = parser.parse_args()
101
+
102
+ if args.seed is not None:
103
+ random.seed(args.seed)
104
+ torch.manual_seed(args.seed)
105
+ cudnn.deterministic = True
106
+ warnings.warn('You have chosen to seed training. '
107
+ 'This will turn on the CUDNN deterministic setting, '
108
+ 'which can slow down your training considerably! '
109
+ 'You may see unexpected behavior when restarting '
110
+ 'from checkpoints.')
111
+ # utils.init_distributed_mode(args)
112
+ if args.gpu is not None:
113
+ warnings.warn('You have chosen a specific GPU. This will completely '
114
+ 'disable data parallelism.')
115
+
116
+ if args.dist_url == "env://" and args.world_size == -1:
117
+ args.world_size = int(os.environ["WORLD_SIZE"])
118
+
119
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
120
+
121
+ ngpus_per_node = torch.cuda.device_count()
122
+ if args.multiprocessing_distributed:
123
+ # Since we have ngpus_per_node processes per node, the total world_size
124
+ # needs to be adjusted accordingly
125
+ args.world_size = ngpus_per_node * args.world_size
126
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
127
+ # main_worker process function
128
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
129
+ else:
130
+ # Simply call main_worker function
131
+ main_worker(args.gpu, ngpus_per_node, args)
132
+
133
+
134
+ def main_worker(gpu, ngpus_per_node, args):
135
+ global best_acc1
136
+ args.gpu = gpu
137
+
138
+ # suppress printing if not master
139
+ if args.multiprocessing_distributed and args.gpu != 0:
140
+ def print_pass(*args):
141
+ pass
142
+
143
+ builtins.print = print_pass
144
+
145
+ if args.gpu is not None:
146
+ print("Use GPU: {} for training".format(args.gpu))
147
+
148
+ if args.distributed:
149
+ if args.dist_url == "env://" and args.rank == -1:
150
+ args.rank = int(os.environ["RANK"])
151
+ if args.multiprocessing_distributed:
152
+ # For multiprocessing distributed training, rank needs to be the
153
+ # global rank among all the processes
154
+ args.rank = args.rank * ngpus_per_node + gpu
155
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
156
+ world_size=args.world_size, rank=args.rank)
157
+ torch.distributed.barrier()
158
+
159
+ # create model
160
+ if args.pt_style == 'dino':
161
+ dinomapping = {
162
+ 'vit_base': 'dino_vitb16',
163
+ 'vit_base8': 'dino_vitb8', # TODO: this mapping is incorrect. Change it later
164
+ }
165
+ if args.arch not in dinomapping:
166
+ raise NotImplementedError('This model type does not exist/supported for DINO')
167
+ model = dino_vits.__dict__[dinomapping[args.arch]](
168
+ pretrained=True
169
+ )
170
+ elif args.pt_style == 'moco':
171
+ if args.arch == 'vit_base':
172
+ model = moco_vits.__dict__[args.arch]()
173
+ pretrained = torch.load('./pretrainedmodels/vit-b-300ep.pth.tar', map_location='cpu')
174
+ state_dict = pretrained['state_dict']
175
+ for k in list(state_dict.keys()):
176
+ # retain only base_encoder up to before the embedding layer
177
+ if k.startswith('module.base_encoder'):
178
+ # remove prefix
179
+ state_dict[k[len("module.base_encoder."):]] = state_dict[k]
180
+ # delete renamed or unused k
181
+ del state_dict[k]
182
+ model.load_state_dict(state_dict, strict=False)
183
+ else:
184
+ raise NotImplementedError('This model type does not exist/supported for MoCo')
185
+ elif args.pt_style == 'clip':
186
+ from models import clip
187
+ clipmapping = {
188
+ 'vit_large': 'ViT-L/14',
189
+ 'vit_base': 'ViT-B/16',
190
+ }
191
+ if args.arch not in clipmapping:
192
+ raise NotImplementedError('This model type does not exist/supported for CLIP')
193
+ model, preprocess = clip.load(clipmapping[args.arch])
194
+ elif args.pt_style == 'vgg':
195
+ model = torchvision_models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
196
+ elif args.pt_style == 'sscd':
197
+ if args.arch == 'resnet50':
198
+ model = torch.jit.load("./pretrainedmodels/sscd_disc_mixup.torchscript.pt")
199
+ elif args.arch == 'resnet50_disc':
200
+ model = torch.jit.load("./pretrainedmodels/sscd_disc_large.torchscript.pt")
201
+ else:
202
+ NotImplementedError('This model type does not exist/supported for SSCD')
203
+ elif args.pt_style.startswith('csd'):
204
+ assert args.model_path is not None, "Model path missing for CSD model"
205
+ from CSD.model import CSD_CLIP
206
+ from CSD.utils import has_batchnorms, convert_state_dict
207
+ from CSD.loss_utils import transforms_branch0
208
+
209
+ args.content_proj_head = "default"
210
+ model = CSD_CLIP(args.arch, args.content_proj_head)
211
+ if has_batchnorms(model):
212
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
213
+
214
+ checkpoint = torch.load(args.model_path, map_location="cpu")
215
+ state_dict = convert_state_dict(checkpoint['model_state_dict'])
216
+ msg = model.load_state_dict(state_dict, strict=False)
217
+ print(f"=> loaded checkpoint with msg {msg}")
218
+ preprocess = transforms_branch0
219
+
220
+ if not torch.cuda.is_available():
221
+ print('using CPU, this will be slow')
222
+ elif args.distributed:
223
+ # For multiprocessing distributed, DistributedDataParallel constructor
224
+ # should always set the single device scope, otherwise,
225
+ # DistributedDataParallel will use all available devices.
226
+ if args.gpu is not None:
227
+ torch.cuda.set_device(args.gpu)
228
+ model.cuda(args.gpu)
229
+ # When using a single GPU per process and per
230
+ # DistributedDataParallel, we need to divide the batch size
231
+ # ourselves based on the total number of GPUs we have
232
+ args.batch_size = int(args.batch_size / args.world_size)
233
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
234
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
235
+ else:
236
+ model.cuda()
237
+ # DistributedDataParallel will divide and allocate batch_size to all
238
+ # available GPUs if device_ids are not set
239
+ model = torch.nn.parallel.DistributedDataParallel(model)
240
+ elif args.gpu is not None:
241
+ torch.cuda.set_device(args.gpu)
242
+ model = model.cuda(args.gpu)
243
+ else:
244
+ # DataParallel will divide and allocate batch_size to all available GPUs
245
+ if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
246
+ model.features = torch.nn.DataParallel(model.features)
247
+ model.cuda()
248
+ model = torch.nn.DataParallel(model).cuda()
249
+
250
+ cudnn.benchmark = True
251
+
252
+ # Data loading code
253
+ if args.pt_style == 'clip': # and args.arch == 'resnet50':
254
+ ret_transform = preprocess
255
+ elif args.pt_style.startswith('csd'):
256
+ ret_transform = preprocess
257
+ elif args.pt_style in ['dino', 'moco', 'vgg']:
258
+ ret_transform = transforms.Compose([
259
+ transforms.Resize(256),
260
+ transforms.CenterCrop(224),
261
+ transforms.ToTensor(),
262
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
263
+ ])
264
+ else:
265
+ ret_transform = transforms.Compose([
266
+ transforms.Resize(256),
267
+ transforms.CenterCrop(224),
268
+ transforms.ToTensor(),
269
+ transforms.Normalize([0.5], [0.5]),
270
+ ])
271
+
272
+ if args.dataset == 'wikiart':
273
+ dataset_query = WikiArtD(args.data_dir, args.qsplit, ret_transform)
274
+ dataset_values = WikiArtD(args.data_dir, 'database', ret_transform)
275
+ else:
276
+ raise NotImplementedError
277
+
278
+ ## creating dataloader
279
+ if args.distributed:
280
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset_values, shuffle=False)
281
+ qsampler = torch.utils.data.distributed.DistributedSampler(dataset_query, shuffle=False)
282
+ else:
283
+ sampler = None
284
+ qsampler = None
285
+ data_loader_values = torch.utils.data.DataLoader(
286
+ dataset_values,
287
+ sampler=sampler,
288
+ batch_size=args.batch_size,
289
+ num_workers=args.workers,
290
+ pin_memory=True,
291
+ drop_last=False,
292
+ )
293
+ data_loader_query = torch.utils.data.DataLoader(
294
+ dataset_query,
295
+ sampler=qsampler,
296
+ batch_size=args.batch_size if args.feattype != 'gram' else 32,
297
+ num_workers=args.workers,
298
+ pin_memory=True,
299
+ drop_last=False,
300
+ )
301
+ print(f"train: {len(dataset_values)} imgs / query: {len(dataset_query)} imgs")
302
+ model.eval()
303
+
304
+ ############################################################################
305
+ if not args.multiprocessing_distributed:
306
+ utils.init_distributed_mode(args)
307
+ if args.rank == 0: # only rank 0 will work from now on
308
+
309
+ # Step 1: extract features
310
+ os.makedirs(args.embed_dir, exist_ok=True)
311
+ embsavepath = os.path.join(
312
+ args.embed_dir,
313
+ f'{args.pt_style}_{args.arch}_{args.dataset}_{args.feattype}',
314
+ f'{str(args.layer)}')
315
+ if args.feattype == 'gram':
316
+ path1, path2 = embsavepath.split('_gram')
317
+ embsavepath = '_'.join([path1, 'gram', str(args.gram_dims), args.qsplit, path2])
318
+
319
+ if os.path.isfile(os.path.join(embsavepath, 'database/embeddings_0.pkl')) or args.skip_val:
320
+ valexist = True
321
+ else:
322
+ valexist = False
323
+ if args.feattype == 'gram':
324
+ pca_dirs, meanvals = None, None
325
+ query_features, pca_dirs = extract_features_pca(args, model, pca_dirs, args.gram_dims, data_loader_query,
326
+ False, multiscale=args.multiscale)
327
+ if not valexist:
328
+ values_features, _ = extract_features_pca(args, model, pca_dirs, args.gram_dims, data_loader_values,
329
+ False, multiscale=args.multiscale)
330
+
331
+ elif args.pt_style.startswith('csd'):
332
+ from CSD.utils import extract_features
333
+ query_features = extract_features(model, data_loader_query, use_cuda=False, use_fp16=True, eval_embed=args.eval_embed)
334
+
335
+ if not valexist:
336
+ values_features = extract_features(model, data_loader_values, use_cuda=False, use_fp16=True, eval_embed=args.eval_embed)
337
+ else:
338
+ from utils import extract_features
339
+ query_features = extract_features(args, model, data_loader_query, False, multiscale=args.multiscale)
340
+ if not valexist:
341
+ values_features = extract_features(args, model, data_loader_values, False,
342
+ multiscale=args.multiscale)
343
+
344
+ from search.embeddings import save_chunk
345
+ l_query_features = list(np.asarray(query_features.cpu().detach(), dtype=np.float16))
346
+
347
+ save_chunk(l_query_features, dataset_query.namelist, 0, f'{embsavepath}/{args.qsplit}')
348
+ if not valexist:
349
+ l_values_features = list(np.asarray(values_features.cpu().detach(), dtype=np.float16))
350
+ save_chunk(l_values_features, dataset_values.namelist, 0, f'{embsavepath}/database')
351
+
352
+ print(f'Embeddings saved to: {embsavepath}')
353
+
354
+
355
+ if __name__ == '__main__':
356
+ main()
CSD/metrics/__init__.py ADDED
File without changes
CSD/metrics/metrics.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import sys
3
+
4
+ import numpy as np
5
+
6
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
7
+
8
+
9
+ class Metrics(object):
10
+ def __init__(self):
11
+ self.data = None
12
+
13
+ @staticmethod
14
+ def get_recall(preds, gts, topk=5):
15
+ preds = preds[:, :topk]
16
+ preds -= gts[:, None]
17
+ found = np.where(np.amin(np.absolute(preds), axis=1) == 0)[0]
18
+ return found.shape[0] / gts.shape[0]
19
+
20
+ @staticmethod
21
+ def get_mrr(preds, gts, topk=5):
22
+ preds = preds[:, :topk]
23
+ preds -= gts[:, None]
24
+ rows, cols = np.where(preds == 0)
25
+ _, unique_rows = np.unique(rows, return_index=True)
26
+ valid_cols = cols[unique_rows]
27
+ valid_cols += 1
28
+ return np.mean(1/valid_cols)
29
+
30
+ @staticmethod
31
+ def get_map(preds, gts, topk=5):
32
+ preds = preds[:, :topk]
33
+ preds -= gts[:, None]
34
+ rows, cols = np.where(preds == 0)
35
+ _, unique_rows = np.unique(rows, return_index=True)
36
+ row_cols = np.split(cols, unique_rows)[1:]
37
+ row_cols = [np.hstack([x[0], np.diff(x), topk - x[-1]]) for x in row_cols]
38
+ row_cols = [np.pad(x, (0, topk + 1 - x.shape[0]), 'constant', constant_values=(0, 0)) for x in row_cols]
39
+ precision = np.asarray([np.repeat(np.arange(topk + 1), x) / np.arange(1, topk + 1) for x in row_cols])
40
+ return np.sum(np.mean(precision, axis=1)) / preds.shape[0]
41
+ # numpy increasing array according to bins
42
+
43
+ @staticmethod
44
+ def get_recall_bin(preds, topk=5):
45
+ # preds is a binary matrix of size Q x K
46
+ preds = preds[:, :topk]
47
+ found = np.where(np.amax(preds, axis=1) == True)[0]
48
+ return found.shape[0] / preds.shape[0]
49
+
50
+ @staticmethod
51
+ def get_mrr_bin(preds, topk=5):
52
+ # preds is a binary matrix of size Q x K
53
+ preds = preds[:, :topk]
54
+ rows, cols = np.where(preds)
55
+ _, unique_rows = np.unique(rows, return_index=True)
56
+ valid_cols = cols[unique_rows]
57
+ valid_cols += 1
58
+ return np.mean(1/valid_cols)
59
+
60
+ @staticmethod
61
+ def get_map_bin(preds, topk=5):
62
+ # preds is a binary matrix of size Q x K
63
+ preds = preds[:, :topk]
64
+ rows, cols = np.where(preds)
65
+ _, unique_rows = np.unique(rows, return_index=True)
66
+ row_cols = np.split(cols, unique_rows)[1:]
67
+ row_cols = [np.hstack([x[0], np.diff(x), topk - x[-1]]) for x in row_cols]
68
+ row_cols = [np.pad(x, (0, topk + 1 - x.shape[0]), 'constant', constant_values=(0, 0)) for x in row_cols]
69
+ precision = np.asarray([np.repeat(np.arange(topk + 1), x) / np.arange(1, topk + 1) for x in row_cols])
70
+ return np.sum(np.mean(precision, axis=1)) / preds.shape[0]
71
+
72
+ @staticmethod
73
+ def get_per_query_precision_bin(preds):
74
+ return np.sum(preds, axis=1)/preds.shape[1]
CSD/models/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
CSD/models/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
CSD/models/clip/clip.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def _transform(n_px):
80
+ return Compose([
81
+ Resize(n_px, interpolation=BICUBIC),
82
+ CenterCrop(n_px),
83
+ _convert_image_to_rgb,
84
+ ToTensor(),
85
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
+ ])
87
+
88
+
89
+ def available_models() -> List[str]:
90
+ """Returns the names of available CLIP models"""
91
+ return list(_MODELS.keys())
92
+
93
+
94
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95
+ """Load a CLIP model
96
+
97
+ Parameters
98
+ ----------
99
+ name : str
100
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
+
102
+ device : Union[str, torch.device]
103
+ The device to put the loaded model
104
+
105
+ jit : bool
106
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
107
+
108
+ download_root: str
109
+ path to download the model files; by default, it uses "~/.cache/clip"
110
+
111
+ Returns
112
+ -------
113
+ model : torch.nn.Module
114
+ The CLIP model
115
+
116
+ preprocess : Callable[[PIL.Image], torch.Tensor]
117
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118
+ """
119
+ if name in _MODELS:
120
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121
+ elif os.path.isfile(name):
122
+ model_path = name
123
+ else:
124
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125
+
126
+ with open(model_path, 'rb') as opened_file:
127
+ try:
128
+ # loading JIT archive
129
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130
+ state_dict = None
131
+ except RuntimeError:
132
+ # loading saved state dict
133
+ if jit:
134
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135
+ jit = False
136
+ state_dict = torch.load(opened_file, map_location="cpu")
137
+
138
+ if not jit:
139
+ model = build_model(state_dict or model.state_dict()).to(device)
140
+ if str(device) == "cpu":
141
+ model.float()
142
+ return model, _transform(model.visual.input_resolution)
143
+
144
+ # patch the device names
145
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147
+
148
+ def patch_device(module):
149
+ try:
150
+ graphs = [module.graph] if hasattr(module, "graph") else []
151
+ except RuntimeError:
152
+ graphs = []
153
+
154
+ if hasattr(module, "forward1"):
155
+ graphs.append(module.forward1.graph)
156
+
157
+ for graph in graphs:
158
+ for node in graph.findAllNodes("prim::Constant"):
159
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160
+ node.copyAttributes(device_node)
161
+
162
+ model.apply(patch_device)
163
+ patch_device(model.encode_image)
164
+ patch_device(model.encode_text)
165
+
166
+ # patch dtype to float32 on CPU
167
+ if str(device) == "cpu":
168
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170
+ float_node = float_input.node()
171
+
172
+ def patch_float(module):
173
+ try:
174
+ graphs = [module.graph] if hasattr(module, "graph") else []
175
+ except RuntimeError:
176
+ graphs = []
177
+
178
+ if hasattr(module, "forward1"):
179
+ graphs.append(module.forward1.graph)
180
+
181
+ for graph in graphs:
182
+ for node in graph.findAllNodes("aten::to"):
183
+ inputs = list(node.inputs())
184
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185
+ if inputs[i].node()["value"] == 5:
186
+ inputs[i].node().copyAttributes(float_node)
187
+
188
+ model.apply(patch_float)
189
+ patch_float(model.encode_image)
190
+ patch_float(model.encode_text)
191
+
192
+ model.float()
193
+
194
+ return model, _transform(model.input_resolution.item())
195
+
196
+
197
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198
+ """
199
+ Returns the tokenized representation of given input string(s)
200
+
201
+ Parameters
202
+ ----------
203
+ texts : Union[str, List[str]]
204
+ An input string or a list of input strings to tokenize
205
+
206
+ context_length : int
207
+ The context length to use; all CLIP models use 77 as the context length
208
+
209
+ truncate: bool
210
+ Whether to truncate the text in case its encoding is longer than the context length
211
+
212
+ Returns
213
+ -------
214
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216
+ """
217
+ if isinstance(texts, str):
218
+ texts = [texts]
219
+
220
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
221
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
222
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225
+ else:
226
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227
+
228
+ for i, tokens in enumerate(all_tokens):
229
+ if len(tokens) > context_length:
230
+ if truncate:
231
+ tokens = tokens[:context_length]
232
+ tokens[-1] = eot_token
233
+ else:
234
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235
+ result[i, :len(tokens)] = torch.tensor(tokens)
236
+
237
+ return result
CSD/models/clip/model.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+ def get_intermediate_layers(self, x):
156
+ def stem(x):
157
+ x = self.relu1(self.bn1(self.conv1(x)))
158
+ x = self.relu2(self.bn2(self.conv2(x)))
159
+ x = self.relu3(self.bn3(self.conv3(x)))
160
+ x = self.avgpool(x)
161
+ return x
162
+
163
+ x = x.type(self.conv1.weight.dtype)
164
+ output = []
165
+ x = stem(x)
166
+ output.append(x)
167
+ x = self.layer1(x)
168
+ output.append(x)
169
+ x = self.layer2(x)
170
+ output.append(x)
171
+ x = self.layer3(x)
172
+ output.append(x)
173
+ x = self.layer4(x)
174
+ output.append(x)
175
+ x = self.attnpool(x)
176
+ output.append(x)
177
+ return output
178
+
179
+ class LayerNorm(nn.LayerNorm):
180
+ """Subclass torch's LayerNorm to handle fp16."""
181
+
182
+ def forward(self, x: torch.Tensor):
183
+ orig_type = x.dtype
184
+ ret = super().forward(x.type(torch.float32))
185
+ return ret.type(orig_type)
186
+
187
+
188
+ class QuickGELU(nn.Module):
189
+ def forward(self, x: torch.Tensor):
190
+ return x * torch.sigmoid(1.702 * x)
191
+
192
+
193
+ class ResidualAttentionBlock(nn.Module):
194
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
195
+ super().__init__()
196
+
197
+ self.attn = nn.MultiheadAttention(d_model, n_head)
198
+ self.ln_1 = LayerNorm(d_model)
199
+ self.mlp = nn.Sequential(OrderedDict([
200
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
201
+ ("gelu", QuickGELU()),
202
+ ("c_proj", nn.Linear(d_model * 4, d_model))
203
+ ]))
204
+ self.ln_2 = LayerNorm(d_model)
205
+ self.attn_mask = attn_mask
206
+
207
+ def attention(self, x: torch.Tensor):
208
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
209
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
210
+
211
+ def forward(self, x: torch.Tensor):
212
+ x = x + self.attention(self.ln_1(x))
213
+ x = x + self.mlp(self.ln_2(x))
214
+ return x
215
+
216
+
217
+ class Transformer(nn.Module):
218
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
219
+ super().__init__()
220
+ self.width = width
221
+ self.layers = layers
222
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
223
+
224
+ def forward(self, x: torch.Tensor):
225
+ return self.resblocks(x)
226
+
227
+ def get_activations(self, x: torch.Tensor):
228
+ output = []
229
+
230
+ for i in range(self.layers):
231
+ # import ipdb; ipdb.set_trace()
232
+ x = self.resblocks[i](x)
233
+ output.append(x.permute(1, 0, 2))
234
+ return output
235
+
236
+
237
+
238
+ class VisionTransformer(nn.Module):
239
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
240
+ super().__init__()
241
+ self.input_resolution = input_resolution
242
+ self.output_dim = output_dim
243
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
244
+
245
+ scale = width ** -0.5
246
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
247
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
248
+ self.ln_pre = LayerNorm(width)
249
+
250
+ self.transformer = Transformer(width, layers, heads)
251
+
252
+ self.ln_post = LayerNorm(width)
253
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
254
+
255
+ def forward(self, x: torch.Tensor):
256
+ x = self.conv1(x) # shape = [*, width, grid, grid]
257
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
258
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
259
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
260
+ x = x + self.positional_embedding.to(x.dtype)
261
+ x = self.ln_pre(x)
262
+
263
+ x = x.permute(1, 0, 2) # NLD -> LND
264
+ x = self.transformer(x)
265
+ x = x.permute(1, 0, 2) # LND -> NLD
266
+
267
+ x = self.ln_post(x[:, 0, :])
268
+
269
+ if self.proj is not None:
270
+ x = x @ self.proj
271
+
272
+ return x
273
+
274
+ def get_intermediate_layers(self, x: torch.Tensor):
275
+ x = self.conv1(x) # shape = [*, width, grid, grid]
276
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
277
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
278
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
279
+ x = x + self.positional_embedding.to(x.dtype)
280
+ x = self.ln_pre(x)
281
+
282
+ x = x.permute(1, 0, 2) # NLD -> LND
283
+ # x = self.transformer(x)
284
+ op = self.transformer.get_activations(x)
285
+ # x = x.permute(1, 0, 2) # LND -> NLD
286
+
287
+ # x = self.ln_post(x[:, 0, :])
288
+
289
+ # if self.proj is not None:
290
+ # x = x @ self.proj
291
+ return op
292
+
293
+ class CLIP(nn.Module):
294
+ def __init__(self,
295
+ embed_dim: int,
296
+ # vision
297
+ image_resolution: int,
298
+ vision_layers: Union[Tuple[int, int, int, int], int],
299
+ vision_width: int,
300
+ vision_patch_size: int,
301
+ # text
302
+ context_length: int,
303
+ vocab_size: int,
304
+ transformer_width: int,
305
+ transformer_heads: int,
306
+ transformer_layers: int
307
+ ):
308
+ super().__init__()
309
+
310
+ self.context_length = context_length
311
+
312
+ if isinstance(vision_layers, (tuple, list)):
313
+ vision_heads = vision_width * 32 // 64
314
+ self.visual = ModifiedResNet(
315
+ layers=vision_layers,
316
+ output_dim=embed_dim,
317
+ heads=vision_heads,
318
+ input_resolution=image_resolution,
319
+ width=vision_width
320
+ )
321
+ else:
322
+ vision_heads = vision_width // 64
323
+ self.visual = VisionTransformer(
324
+ input_resolution=image_resolution,
325
+ patch_size=vision_patch_size,
326
+ width=vision_width,
327
+ layers=vision_layers,
328
+ heads=vision_heads,
329
+ output_dim=embed_dim
330
+ )
331
+
332
+ self.transformer = Transformer(
333
+ width=transformer_width,
334
+ layers=transformer_layers,
335
+ heads=transformer_heads,
336
+ attn_mask=self.build_attention_mask()
337
+ )
338
+
339
+ self.vocab_size = vocab_size
340
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
341
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
342
+ self.ln_final = LayerNorm(transformer_width)
343
+
344
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
345
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
346
+
347
+ self.initialize_parameters()
348
+
349
+ def initialize_parameters(self):
350
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
351
+ nn.init.normal_(self.positional_embedding, std=0.01)
352
+
353
+ if isinstance(self.visual, ModifiedResNet):
354
+ if self.visual.attnpool is not None:
355
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
356
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
357
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
358
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
359
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
360
+
361
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
362
+ for name, param in resnet_block.named_parameters():
363
+ if name.endswith("bn3.weight"):
364
+ nn.init.zeros_(param)
365
+
366
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
367
+ attn_std = self.transformer.width ** -0.5
368
+ fc_std = (2 * self.transformer.width) ** -0.5
369
+ for block in self.transformer.resblocks:
370
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
371
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
372
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
373
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
374
+
375
+ if self.text_projection is not None:
376
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
377
+
378
+ def build_attention_mask(self):
379
+ # lazily create causal attention mask, with full attention between the vision tokens
380
+ # pytorch uses additive attention mask; fill with -inf
381
+ mask = torch.empty(self.context_length, self.context_length)
382
+ mask.fill_(float("-inf"))
383
+ mask.triu_(1) # zero out the lower diagonal
384
+ return mask
385
+
386
+ @property
387
+ def dtype(self):
388
+ return self.visual.conv1.weight.dtype
389
+
390
+ def encode_image(self, image):
391
+ return self.visual(image.type(self.dtype))
392
+
393
+ def encode_text(self, text):
394
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
395
+
396
+ x = x + self.positional_embedding.type(self.dtype)
397
+ x = x.permute(1, 0, 2) # NLD -> LND
398
+ x = self.transformer(x)
399
+ x = x.permute(1, 0, 2) # LND -> NLD
400
+ x = self.ln_final(x).type(self.dtype)
401
+
402
+ # x.shape = [batch_size, n_ctx, transformer.width]
403
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
404
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
405
+
406
+ return x
407
+
408
+ def forward(self, image, text):
409
+ image_features = self.encode_image(image)
410
+ text_features = self.encode_text(text)
411
+
412
+ # normalized features
413
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
414
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
415
+
416
+ # cosine similarity as logits
417
+ logit_scale = self.logit_scale.exp()
418
+ logits_per_image = logit_scale * image_features @ text_features.t()
419
+ logits_per_text = logits_per_image.t()
420
+
421
+ # shape = [global_batch_size, global_batch_size]
422
+ return logits_per_image, logits_per_text
423
+
424
+
425
+ def convert_weights(model: nn.Module):
426
+ """Convert applicable model parameters to fp16"""
427
+
428
+ def _convert_weights_to_fp16(l):
429
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
430
+ l.weight.data = l.weight.data.half()
431
+ if l.bias is not None:
432
+ l.bias.data = l.bias.data.half()
433
+
434
+ if isinstance(l, nn.MultiheadAttention):
435
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
436
+ tensor = getattr(l, attr)
437
+ if tensor is not None:
438
+ tensor.data = tensor.data.half()
439
+
440
+ for name in ["text_projection", "proj"]:
441
+ if hasattr(l, name):
442
+ attr = getattr(l, name)
443
+ if attr is not None:
444
+ attr.data = attr.data.half()
445
+
446
+ model.apply(_convert_weights_to_fp16)
447
+
448
+
449
+ def build_model(state_dict: dict):
450
+ vit = "visual.proj" in state_dict
451
+
452
+ if vit:
453
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
454
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
455
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
456
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
457
+ image_resolution = vision_patch_size * grid_size
458
+ else:
459
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
460
+ vision_layers = tuple(counts)
461
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
462
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
463
+ vision_patch_size = None
464
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
465
+ image_resolution = output_width * 32
466
+
467
+ embed_dim = state_dict["text_projection"].shape[1]
468
+ context_length = state_dict["positional_embedding"].shape[0]
469
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
470
+ transformer_width = state_dict["ln_final.weight"].shape[0]
471
+ transformer_heads = transformer_width // 64
472
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
473
+
474
+ model = CLIP(
475
+ embed_dim,
476
+ image_resolution, vision_layers, vision_width, vision_patch_size,
477
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
478
+ )
479
+
480
+ for key in ["input_resolution", "context_length", "vocab_size"]:
481
+ if key in state_dict:
482
+ del state_dict[key]
483
+
484
+ convert_weights(model)
485
+ model.load_state_dict(state_dict)
486
+ return model.eval()
CSD/models/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
CSD/models/dino_vits.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Copied from dino transformers and added the global pool layer
16
+ """
17
+ import math
18
+ from functools import partial
19
+ import warnings
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+
24
+
25
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
26
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
27
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
28
+ def norm_cdf(x):
29
+ # Computes standard normal cumulative distribution function
30
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
31
+
32
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
33
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
34
+ "The distribution of values may be incorrect.",
35
+ stacklevel=2)
36
+
37
+ with torch.no_grad():
38
+ # Values are generated by using a truncated uniform distribution and
39
+ # then using the inverse CDF for the normal distribution.
40
+ # Get upper and lower cdf values
41
+ l = norm_cdf((a - mean) / std)
42
+ u = norm_cdf((b - mean) / std)
43
+
44
+ # Uniformly fill tensor with values from [l, u], then translate to
45
+ # [2l-1, 2u-1].
46
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
47
+
48
+ # Use inverse cdf transform for normal distribution to get truncated
49
+ # standard normal
50
+ tensor.erfinv_()
51
+
52
+ # Transform to proper mean, std
53
+ tensor.mul_(std * math.sqrt(2.))
54
+ tensor.add_(mean)
55
+
56
+ # Clamp to ensure it's in the proper range
57
+ tensor.clamp_(min=a, max=b)
58
+ return tensor
59
+
60
+
61
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
62
+ # type: (Tensor, float, float, float, float) -> Tensor
63
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
64
+
65
+
66
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
67
+ if drop_prob == 0. or not training:
68
+ return x
69
+ keep_prob = 1 - drop_prob
70
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
71
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
72
+ random_tensor.floor_() # binarize
73
+ output = x.div(keep_prob) * random_tensor
74
+ return output
75
+
76
+
77
+ class DropPath(nn.Module):
78
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
79
+ """
80
+ def __init__(self, drop_prob=None):
81
+ super(DropPath, self).__init__()
82
+ self.drop_prob = drop_prob
83
+
84
+ def forward(self, x):
85
+ return drop_path(x, self.drop_prob, self.training)
86
+
87
+
88
+ class Mlp(nn.Module):
89
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
90
+ super().__init__()
91
+ out_features = out_features or in_features
92
+ hidden_features = hidden_features or in_features
93
+ self.fc1 = nn.Linear(in_features, hidden_features)
94
+ self.act = act_layer()
95
+ self.fc2 = nn.Linear(hidden_features, out_features)
96
+ self.drop = nn.Dropout(drop)
97
+
98
+ def forward(self, x):
99
+ x = self.fc1(x)
100
+ x = self.act(x)
101
+ x = self.drop(x)
102
+ x = self.fc2(x)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+
107
+ class Attention(nn.Module):
108
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
109
+ super().__init__()
110
+ self.num_heads = num_heads
111
+ head_dim = dim // num_heads
112
+ self.scale = qk_scale or head_dim ** -0.5
113
+
114
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
115
+ self.attn_drop = nn.Dropout(attn_drop)
116
+ self.proj = nn.Linear(dim, dim)
117
+ self.proj_drop = nn.Dropout(proj_drop)
118
+
119
+ def forward(self, x):
120
+ B, N, C = x.shape
121
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
122
+ q, k, v = qkv[0], qkv[1], qkv[2]
123
+
124
+ attn = (q @ k.transpose(-2, -1)) * self.scale
125
+ attn = attn.softmax(dim=-1)
126
+ attn = self.attn_drop(attn)
127
+
128
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
129
+ x = self.proj(x)
130
+ x = self.proj_drop(x)
131
+ return x, attn
132
+
133
+
134
+ class Block(nn.Module):
135
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
136
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
137
+ super().__init__()
138
+ self.norm1 = norm_layer(dim)
139
+ self.attn = Attention(
140
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
141
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
142
+ self.norm2 = norm_layer(dim)
143
+ mlp_hidden_dim = int(dim * mlp_ratio)
144
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
145
+
146
+ def forward(self, x, return_attention=False):
147
+ y, attn = self.attn(self.norm1(x))
148
+ if return_attention:
149
+ return attn
150
+ x = x + self.drop_path(y)
151
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
152
+ return x
153
+
154
+
155
+ class PatchEmbed(nn.Module):
156
+ """ Image to Patch Embedding
157
+ """
158
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
159
+ super().__init__()
160
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
161
+ self.img_size = img_size
162
+ self.patch_size = patch_size
163
+ self.num_patches = num_patches
164
+
165
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
166
+
167
+ def forward(self, x):
168
+ B, C, H, W = x.shape
169
+ x = self.proj(x).flatten(2).transpose(1, 2)
170
+ return x
171
+
172
+
173
+ class VisionTransformer(nn.Module):
174
+ """ Vision Transformer """
175
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
176
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
177
+ drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool='token',**kwargs):
178
+ super().__init__()
179
+ self.num_features = self.embed_dim = embed_dim
180
+
181
+ self.patch_embed = PatchEmbed(
182
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
183
+ num_patches = self.patch_embed.num_patches
184
+
185
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
186
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
187
+ self.pos_drop = nn.Dropout(p=drop_rate)
188
+
189
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
190
+ self.blocks = nn.ModuleList([
191
+ Block(
192
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
193
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
194
+ for i in range(depth)])
195
+ self.norm = norm_layer(embed_dim)
196
+
197
+ # Classifier head
198
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
199
+
200
+ trunc_normal_(self.pos_embed, std=.02)
201
+ trunc_normal_(self.cls_token, std=.02)
202
+ self.apply(self._init_weights)
203
+
204
+ self.global_pool = global_pool
205
+
206
+ def _init_weights(self, m):
207
+ if isinstance(m, nn.Linear):
208
+ trunc_normal_(m.weight, std=.02)
209
+ if isinstance(m, nn.Linear) and m.bias is not None:
210
+ nn.init.constant_(m.bias, 0)
211
+ elif isinstance(m, nn.LayerNorm):
212
+ nn.init.constant_(m.bias, 0)
213
+ nn.init.constant_(m.weight, 1.0)
214
+
215
+ def interpolate_pos_encoding(self, x, w, h):
216
+ npatch = x.shape[1] - 1
217
+ N = self.pos_embed.shape[1] - 1
218
+ if npatch == N and w == h:
219
+ return self.pos_embed
220
+ class_pos_embed = self.pos_embed[:, 0]
221
+ patch_pos_embed = self.pos_embed[:, 1:]
222
+ dim = x.shape[-1]
223
+ w0 = w // self.patch_embed.patch_size
224
+ h0 = h // self.patch_embed.patch_size
225
+ # we add a small number to avoid floating point error in the interpolation
226
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
227
+ w0, h0 = w0 + 0.1, h0 + 0.1
228
+ patch_pos_embed = nn.functional.interpolate(
229
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
230
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
231
+ mode='bicubic',
232
+ )
233
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
234
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
235
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
236
+
237
+ def prepare_tokens(self, x):
238
+ B, nc, w, h = x.shape
239
+ x = self.patch_embed(x) # patch linear embedding
240
+
241
+ # add the [CLS] token to the embed patch tokens
242
+ cls_tokens = self.cls_token.expand(B, -1, -1)
243
+ x = torch.cat((cls_tokens, x), dim=1)
244
+
245
+ # add positional encoding to each token
246
+ x = x + self.interpolate_pos_encoding(x, w, h)
247
+
248
+ return self.pos_drop(x)
249
+
250
+ def forward(self, x):
251
+ x = self.prepare_tokens(x)
252
+ for blk in self.blocks:
253
+ x = blk(x)
254
+ x = self.norm(x)
255
+ if self.global_pool == 'token':
256
+ return x[:, 0]
257
+ elif self.global_pool == '':
258
+ return x
259
+
260
+ def get_last_selfattention(self, x):
261
+ x = self.prepare_tokens(x)
262
+ for i, blk in enumerate(self.blocks):
263
+ if i < len(self.blocks) - 1:
264
+ x = blk(x)
265
+ else:
266
+ # return attention of the last block
267
+ return blk(x, return_attention=True)
268
+
269
+ def get_intermediate_layers(self, x, n=1):
270
+ x = self.prepare_tokens(x)
271
+ # we return the output tokens from the `n` last blocks
272
+ output = []
273
+ for i, blk in enumerate(self.blocks):
274
+ x = blk(x)
275
+ if len(self.blocks) - i <= n:
276
+ output.append(self.norm(x))
277
+ return output
278
+
279
+
280
+ def vit_tiny(patch_size=16, **kwargs):
281
+ model = VisionTransformer(
282
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
283
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
284
+ return model
285
+
286
+
287
+ def vit_small(patch_size=16, **kwargs):
288
+ model = VisionTransformer(
289
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
290
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
291
+ return model
292
+
293
+
294
+ def vit_base(patch_size=16, **kwargs):
295
+ model = VisionTransformer(
296
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
297
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
298
+ return model
299
+
300
+
301
+ class DINOHead(nn.Module):
302
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
303
+ super().__init__()
304
+ nlayers = max(nlayers, 1)
305
+ if nlayers == 1:
306
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
307
+ else:
308
+ layers = [nn.Linear(in_dim, hidden_dim)]
309
+ if use_bn:
310
+ layers.append(nn.BatchNorm1d(hidden_dim))
311
+ layers.append(nn.GELU())
312
+ for _ in range(nlayers - 2):
313
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
314
+ if use_bn:
315
+ layers.append(nn.BatchNorm1d(hidden_dim))
316
+ layers.append(nn.GELU())
317
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
318
+ self.mlp = nn.Sequential(*layers)
319
+ self.apply(self._init_weights)
320
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
321
+ self.last_layer.weight_g.data.fill_(1)
322
+ if norm_last_layer:
323
+ self.last_layer.weight_g.requires_grad = False
324
+
325
+ def _init_weights(self, m):
326
+ if isinstance(m, nn.Linear):
327
+ trunc_normal_(m.weight, std=.02)
328
+ if isinstance(m, nn.Linear) and m.bias is not None:
329
+ nn.init.constant_(m.bias, 0)
330
+
331
+ def forward(self, x):
332
+ x = self.mlp(x)
333
+ x = nn.functional.normalize(x, dim=-1, p=2)
334
+ x = self.last_layer(x)
335
+ return x
336
+
337
+
338
+ def dino_vits16(pretrained=True, **kwargs):
339
+ """
340
+ ViT-Small/16x16 pre-trained with DINO.
341
+ Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
342
+ """
343
+ model = vit_small(patch_size=16, num_classes=0, **kwargs)
344
+ if pretrained:
345
+ state_dict = torch.hub.load_state_dict_from_url(
346
+ url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
347
+ map_location="cpu",
348
+ )
349
+ model.load_state_dict(state_dict, strict=True)
350
+ return model
351
+
352
+
353
+ def dino_vits8(pretrained=True, **kwargs):
354
+ """
355
+ ViT-Small/8x8 pre-trained with DINO.
356
+ Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
357
+ """
358
+ model = vit_small(patch_size=8, num_classes=0, **kwargs)
359
+ if pretrained:
360
+ state_dict = torch.hub.load_state_dict_from_url(
361
+ url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
362
+ map_location="cpu",
363
+ )
364
+ model.load_state_dict(state_dict, strict=True)
365
+ return model
366
+
367
+
368
+ def dino_vitb16(pretrained=True, **kwargs):
369
+ """
370
+ ViT-Base/16x16 pre-trained with DINO.
371
+ Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
372
+ """
373
+ model = vit_base(patch_size=16, num_classes=0, **kwargs)
374
+ if pretrained:
375
+ state_dict = torch.hub.load_state_dict_from_url(
376
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
377
+ map_location="cpu",
378
+ )
379
+ model.load_state_dict(state_dict, strict=True)
380
+ return model
381
+
382
+
383
+ def dino_vitb8(pretrained=True, **kwargs):
384
+ """
385
+ ViT-Base/8x8 pre-trained with DINO.
386
+ Achieves 77.4% top-1 accuracy on ImageNet with k-NN classification.
387
+ """
388
+ model = vit_base(patch_size=8, num_classes=0, **kwargs)
389
+ if pretrained:
390
+ state_dict = torch.hub.load_state_dict_from_url(
391
+ url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
392
+ map_location="cpu",
393
+ )
394
+ model.load_state_dict(state_dict, strict=True)
395
+ return model
396
+
397
+ def dino_vitb_cifar10(pretrained=True, **kwargs):
398
+ """
399
+ ViT-Base/16x16 pre-trained with DINO.
400
+ Achieves 76.1% top-1 accuracy on ImageNet with k-NN classification.
401
+ """
402
+ model = vit_base(patch_size=16, num_classes=0, **kwargs)
403
+ if pretrained:
404
+ state_dict = torch.hub.load_state_dict_from_url(
405
+ url="https://dl.fbaipublicfiles.com/dino/cifar100_ViT_B_dino.pth",
406
+ map_location="cpu",
407
+ )
408
+ model.load_state_dict(state_dict, strict=False)
409
+ return model
410
+
411
+
412
+
413
+
414
+ def dino_resnet50(pretrained=True, **kwargs):
415
+ """
416
+ ResNet-50 pre-trained with DINO.
417
+ Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
418
+ """
419
+ from torchvision.models.resnet import resnet50
420
+
421
+ model = resnet50(pretrained=False, **kwargs)
422
+ model.fc = torch.nn.Identity()
423
+ if pretrained:
424
+ state_dict = torch.hub.load_state_dict_from_url(
425
+ url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
426
+ map_location="cpu",
427
+ )
428
+ model.load_state_dict(state_dict, strict=False)
429
+ return model
430
+
431
+
432
+ def dino_xcit_small_12_p16(pretrained=True, **kwargs):
433
+ """
434
+ XCiT-Small-12/16 pre-trained with DINO.
435
+ """
436
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs)
437
+ if pretrained:
438
+ state_dict = torch.hub.load_state_dict_from_url(
439
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
440
+ map_location="cpu",
441
+ )
442
+ model.load_state_dict(state_dict, strict=True)
443
+ return model
444
+
445
+
446
+ def dino_xcit_small_12_p8(pretrained=True, **kwargs):
447
+ """
448
+ XCiT-Small-12/8 pre-trained with DINO.
449
+ """
450
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs)
451
+ if pretrained:
452
+ state_dict = torch.hub.load_state_dict_from_url(
453
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
454
+ map_location="cpu",
455
+ )
456
+ model.load_state_dict(state_dict, strict=True)
457
+ return model
458
+
459
+
460
+ def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
461
+ """
462
+ XCiT-Medium-24/16 pre-trained with DINO.
463
+ """
464
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs)
465
+ if pretrained:
466
+ state_dict = torch.hub.load_state_dict_from_url(
467
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
468
+ map_location="cpu",
469
+ )
470
+ model.load_state_dict(state_dict, strict=True)
471
+ return model
472
+
473
+
474
+ def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
475
+ """
476
+ XCiT-Medium-24/8 pre-trained with DINO.
477
+ """
478
+ model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs)
479
+ if pretrained:
480
+ state_dict = torch.hub.load_state_dict_from_url(
481
+ url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
482
+ map_location="cpu",
483
+ )
484
+ model.load_state_dict(state_dict, strict=True)
485
+ return model
CSD/models/moco_vits.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import torch
9
+ import torch.nn as nn
10
+ from functools import partial, reduce
11
+ from operator import mul
12
+
13
+ from timm.models.vision_transformer import VisionTransformer, _cfg
14
+ from timm.models.layers.helpers import to_2tuple
15
+ from timm.models.layers import PatchEmbed
16
+
17
+ __all__ = [
18
+ 'vit_small',
19
+ 'vit_base',
20
+ 'vit_conv_small',
21
+ 'vit_conv_base',
22
+ ]
23
+
24
+
25
+ class VisionTransformerMoCo(VisionTransformer):
26
+ def __init__(self, stop_grad_conv1=False, **kwargs):
27
+ super().__init__(**kwargs)
28
+ # Use fixed 2D sin-cos position embedding
29
+ self.build_2d_sincos_position_embedding()
30
+
31
+ # weight initialization
32
+ for name, m in self.named_modules():
33
+ if isinstance(m, nn.Linear):
34
+ if 'qkv' in name:
35
+ # treat the weights of Q, K, V separately
36
+ val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
37
+ nn.init.uniform_(m.weight, -val, val)
38
+ else:
39
+ nn.init.xavier_uniform_(m.weight)
40
+ nn.init.zeros_(m.bias)
41
+ nn.init.normal_(self.cls_token, std=1e-6)
42
+
43
+ if isinstance(self.patch_embed, PatchEmbed):
44
+ # xavier_uniform initialization
45
+ val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
46
+ nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
47
+ nn.init.zeros_(self.patch_embed.proj.bias)
48
+
49
+ if stop_grad_conv1:
50
+ self.patch_embed.proj.weight.requires_grad = False
51
+ self.patch_embed.proj.bias.requires_grad = False
52
+
53
+ def build_2d_sincos_position_embedding(self, temperature=10000.):
54
+ h, w = self.patch_embed.grid_size
55
+ grid_w = torch.arange(w, dtype=torch.float32)
56
+ grid_h = torch.arange(h, dtype=torch.float32)
57
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
58
+ assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
59
+ pos_dim = self.embed_dim // 4
60
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
61
+ omega = 1. / (temperature**omega)
62
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
63
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
64
+ pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
65
+
66
+ # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
67
+ pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
68
+ self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
69
+ self.pos_embed.requires_grad = False
70
+
71
+
72
+ class ConvStem(nn.Module):
73
+ """
74
+ ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881
75
+ """
76
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
77
+ super().__init__()
78
+
79
+ assert patch_size == 16, 'ConvStem only supports patch size of 16'
80
+ assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
81
+
82
+ img_size = to_2tuple(img_size)
83
+ patch_size = to_2tuple(patch_size)
84
+ self.img_size = img_size
85
+ self.patch_size = patch_size
86
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
87
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
88
+ self.flatten = flatten
89
+
90
+ # build stem, similar to the design in https://arxiv.org/abs/2106.14881
91
+ stem = []
92
+ input_dim, output_dim = 3, embed_dim // 8
93
+ for l in range(4):
94
+ stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
95
+ stem.append(nn.BatchNorm2d(output_dim))
96
+ stem.append(nn.ReLU(inplace=True))
97
+ input_dim = output_dim
98
+ output_dim *= 2
99
+ stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
100
+ self.proj = nn.Sequential(*stem)
101
+
102
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
103
+
104
+ def forward(self, x):
105
+ B, C, H, W = x.shape
106
+ assert H == self.img_size[0] and W == self.img_size[1], \
107
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
108
+ x = self.proj(x)
109
+ if self.flatten:
110
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
111
+ x = self.norm(x)
112
+ return x
113
+
114
+
115
+ def vit_small(**kwargs):
116
+ model = VisionTransformerMoCo(
117
+ patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
118
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
119
+ model.default_cfg = _cfg()
120
+ return model
121
+
122
+ def vit_base(**kwargs):
123
+ model = VisionTransformerMoCo(
124
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
125
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
126
+ model.default_cfg = _cfg()
127
+ return model
128
+
129
+ def vit_conv_small(**kwargs):
130
+ # minus one ViT block
131
+ model = VisionTransformerMoCo(
132
+ patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
133
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
134
+ model.default_cfg = _cfg()
135
+ return model
136
+
137
+ def vit_conv_base(**kwargs):
138
+ # minus one ViT block
139
+ model = VisionTransformerMoCo(
140
+ patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
141
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
142
+ model.default_cfg = _cfg()
143
+ return model
CSD/pretrainedmodels/.gitkeep ADDED
File without changes
CSD/search.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import logging.handlers as handlers
5
+ import pathlib
6
+ import sys
7
+
8
+ import faiss
9
+ import numpy as np
10
+ import vaex as vx
11
+ import wandb
12
+
13
+ sys.path.insert(0, str(pathlib.Path(__file__).parent.resolve()))
14
+
15
+ from search.embeddings import Embeddings
16
+ from search.faiss_search import FaissIndex
17
+ from metrics import metrics
18
+ from data.wikiart import WikiArt
19
+
20
+ logger = logging.getLogger()
21
+
22
+
23
+ def get_parser():
24
+ parser = argparse.ArgumentParser('dynamicDistances-NN Search Module')
25
+ parser.add_argument('--dataset', default='wikiart', type=str, required=True)
26
+ parser.add_argument('--topk', nargs='+', type=int, default=[5],
27
+ help='Number of NN to consider while calculating recall')
28
+ parser.add_argument('--mode', type=str, required=True, choices=['artist', 'label'],
29
+ help='The type of matching to do')
30
+ parser.add_argument('--method', type=str, default='IP', choices=['IP', 'L2'], help='The method to do NN search')
31
+ parser.add_argument('--emb-dir', type=str, default=None,
32
+ help='The directory where per image embeddings are stored (NOT USED when chunked)')
33
+ parser.add_argument('--query_count', default=-1, type=int,
34
+ help='Number of queries to consider. Works only for domainnet')
35
+ parser.add_argument('--chunked', action='store_true', help='If I should read from chunked directory instead')
36
+ parser.add_argument('--query-chunk-dir', type=str, required=True,
37
+ help='The directory where chunked query embeddings should be saved/are already saved')
38
+ parser.add_argument('--database-chunk-dir', type=str, required=True,
39
+ help='The directory where chunked val embeddings should be saved/are already saved')
40
+ parser.add_argument('--data-dir', type=str, default=None,
41
+ help='The directory of concerned dataset. (HARD CODED LATER)')
42
+ parser.add_argument('--multilabel', action='store_true', help='If the dataset is multilabel')
43
+
44
+ return parser
45
+
46
+
47
+ def get_log_handlers(args):
48
+ # Create handlers
49
+ c_handler = logging.StreamHandler()
50
+ f_handler = handlers.RotatingFileHandler(f'search.log', maxBytes=int(1e6), backupCount=1000)
51
+ c_handler.setLevel(logging.DEBUG)
52
+ f_handler.setLevel(logging.DEBUG)
53
+
54
+ # Create formatters and add it to handlers
55
+ c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
56
+ f_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
57
+ c_handler.setFormatter(c_format)
58
+ f_handler.setFormatter(f_format)
59
+ return c_handler, f_handler
60
+
61
+
62
+ def main():
63
+ parser = get_parser()
64
+ args = parser.parse_args()
65
+
66
+ handlers = get_log_handlers(args)
67
+ logger.addHandler(handlers[0])
68
+ logger.addHandler(handlers[1])
69
+ logger.setLevel(logging.DEBUG)
70
+
71
+ if args.dataset == 'wikiart':
72
+ dataset = WikiArt(args.data_dir)
73
+ else:
74
+ raise NotImplementedError
75
+
76
+ query_embeddings = Embeddings(args.emb_dir, args.query_chunk_dir,
77
+ files=list(map(lambda x: f'{x.split(".")[0]}.npy', dataset.query_images)),
78
+ chunked=args.chunked,
79
+ file_ext='.npy')
80
+ val_embeddings = Embeddings(args.emb_dir, args.database_chunk_dir,
81
+ files=list(map(lambda x: f'{x.split(".")[0]}.npy', dataset.val_images)),
82
+ chunked=args.chunked,
83
+ file_ext='.npy')
84
+
85
+ query_embeddings.filenames = list(query_embeddings.filenames)
86
+ val_embeddings.filenames = list(val_embeddings.filenames)
87
+
88
+ # Filtering the dataset based on the files which actually exist.
89
+ dataset.query_db = dataset.query_db[
90
+ dataset.query_db['name'].isin(query_embeddings.filenames)]
91
+ dataset.val_db = dataset.val_db[
92
+ dataset.val_db['name'].isin(val_embeddings.filenames)]
93
+
94
+ # Using only the embeddings corresponding to images in the datasets
95
+ temp = vx.from_arrays(filename=query_embeddings.filenames, index=np.arange(len(query_embeddings.filenames)))
96
+ dataset.query_db = dataset.query_db.join(temp, left_on='name', right_on='filename', how='left')
97
+ query_embeddings.embeddings = query_embeddings.embeddings[dataset.get_query_col('index')]
98
+ try:
99
+ b, h, w = query_embeddings.embeddings.shape
100
+ query_embeddings.embeddings = query_embeddings.embeddings.reshape(b, 1, h * w)
101
+ except ValueError:
102
+ b, d = query_embeddings.embeddings.shape
103
+ query_embeddings.embeddings = query_embeddings.embeddings.reshape(b, 1, d)
104
+ query_embeddings.filenames = np.asarray(query_embeddings.filenames)[dataset.get_query_col('index')]
105
+
106
+ temp = vx.from_arrays(filename=val_embeddings.filenames, index=np.arange(len(val_embeddings.filenames)))
107
+ dataset.val_db = dataset.val_db.join(temp, left_on='name', right_on='filename', how='left')
108
+ val_embeddings.embeddings = val_embeddings.embeddings[dataset.get_val_col('index')]
109
+ try:
110
+ b, h, w = val_embeddings.embeddings.shape
111
+ val_embeddings.embeddings = val_embeddings.embeddings.reshape(b, 1, h * w)
112
+ except ValueError:
113
+ b, d = val_embeddings.embeddings.shape
114
+ val_embeddings.embeddings = val_embeddings.embeddings.reshape(b, 1, d)
115
+ val_embeddings.filenames = np.asarray(val_embeddings.filenames)[dataset.get_val_col('index')]
116
+
117
+ # Building the faiss index
118
+ embedding_size = query_embeddings.embeddings[0].shape[1]
119
+ if args.method == 'IP':
120
+ method = faiss.IndexFlatIP
121
+ else:
122
+ method = faiss.IndexFlatL2
123
+ search_module = FaissIndex(embedding_size=embedding_size, index_func=method)
124
+ queries = np.asarray(query_embeddings.embeddings).reshape(len(query_embeddings.embeddings), embedding_size)
125
+ database = np.asarray(val_embeddings.embeddings).reshape(len(val_embeddings.embeddings), embedding_size)
126
+ search_module.build_index(database)
127
+
128
+ _, nns_all = search_module.search_nns(queries, max(args.topk))
129
+ if args.multilabel:
130
+ q_labels = dataset.query_db['multilabel'].values
131
+ db_labels = dataset.val_db['multilabel'].values
132
+ nns_all_pred = [q_labels[i] @ db_labels[nns_all[i]].T for i in range(len(nns_all))]
133
+ nns_all_pred = np.array(nns_all_pred)
134
+ else:
135
+ nns_all_pred = nns_all
136
+ classes = np.unique(dataset.get_val_col(args.mode))
137
+ mode_to_index = {classname: i for i, classname in enumerate(classes)}
138
+ try:
139
+ gts = np.asarray(list(map(lambda x: mode_to_index[x], dataset.get_query_col(args.mode).tolist())))
140
+ except KeyError:
141
+ logger.error('Class not found in database. This query list cannot be evaluated')
142
+ return
143
+
144
+ evals = metrics.Metrics()
145
+
146
+ for topk in args.topk:
147
+ logger.info(f'Calculating recall@{topk}')
148
+ nns_all_pred_topk = nns_all_pred[:, :topk]
149
+ if args.multilabel:
150
+ mode_recall = evals.get_recall_bin(copy.deepcopy(nns_all_pred_topk), topk)
151
+ mode_mrr = evals.get_mrr_bin(copy.deepcopy(nns_all_pred_topk), topk)
152
+ mode_map = evals.get_map_bin(copy.deepcopy(nns_all_pred_topk), topk)
153
+ else:
154
+ preds = dataset.get_val_col(args.mode)[nns_all_pred_topk.flatten()].reshape(len(queries), topk)
155
+ preds = np.vectorize(mode_to_index.get)(preds)
156
+ mode_recall = evals.get_recall(copy.deepcopy(preds), gts, topk)
157
+ mode_mrr = evals.get_mrr(copy.deepcopy(preds), gts, topk)
158
+ mode_map = evals.get_map(copy.deepcopy(preds), gts, topk)
159
+ logger.info(f'Recall@{topk}: {mode_recall}')
160
+ logger.info(f'MRR@{topk}: {mode_mrr}')
161
+ logger.info(f'mAP@{topk}: {mode_map}')
162
+
163
+
164
+ if __name__ == '__main__':
165
+ main()
CSD/search/__init__.py ADDED
File without changes
CSD/search/embeddings.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import concurrent.futures as concfut
2
+ import glob
3
+ import os
4
+
5
+ import pickle
6
+ import logging
7
+ import queue
8
+ import os.path as osp
9
+ import threading
10
+ from multiprocessing import Process
11
+ import math
12
+ import numpy as np
13
+
14
+ module_logger = logging.getLogger(__name__)
15
+
16
+
17
+ class Embeddings(object):
18
+ """Class to read embeddings from the disk and store them in memory"""
19
+ def __init__(self, data_dir, chunk_dir, file_ext='.pt', files=None, chunked=False, chunk_size=5000):
20
+ if files is not None:
21
+ self.embedding_files = list(map(lambda x: osp.join(data_dir, x), files))
22
+ else:
23
+ self.embedding_files = glob.glob(f'{data_dir}/*{file_ext}')
24
+ self.embedding_queue = queue.Queue()
25
+ self.embeddings = []
26
+ self.filenames = []
27
+ self.chunk_dir = chunk_dir
28
+ self.chunk_size = chunk_size
29
+ self.chunked = chunked
30
+ if not self.chunked:
31
+ threading.Thread(target=self.__result_consumer, daemon=True).start()
32
+ self.__read_embeddings()
33
+ self.embeddings, self.filenames = self.__remove_missing(self.embeddings, self.filenames)
34
+ else:
35
+ self.__read_embeddings_chunked()
36
+ self.__sort_embeddings()
37
+
38
+ def __result_consumer(self):
39
+ """Consumes the results from the embedding queue and saves them to the disk"""
40
+ processed = 0
41
+ fnf = 0 # FileNotFound
42
+ embedding_chunk = []
43
+ filename_chunk = []
44
+ chunk_cnt = 0
45
+ while True:
46
+ data = self.embedding_queue.get()
47
+ if not isinstance(data, str):
48
+ self.filenames.append(data['filename'])
49
+ if data['embedding'] is not None:
50
+ self.embeddings.append(data['embedding'])
51
+ processed += 1
52
+ if processed % 1000 == 0:
53
+ module_logger.info(f'Read {processed}/{len(self.embedding_files)} embeddings')
54
+ else:
55
+ fnf += 1
56
+ self.embeddings.append(None)
57
+ if len(embedding_chunk) < self.chunk_size:
58
+ embedding_chunk.append(data['embedding'])
59
+ filename_chunk.append(data['filename'])
60
+ else:
61
+ chunk_cnt += 1
62
+ embedding_chunk, filename_chunk = self.__remove_missing(embedding_chunk, filename_chunk)
63
+ Process(target=save_chunk, args=(embedding_chunk, filename_chunk, chunk_cnt, self.chunk_dir),
64
+ daemon=True).start()
65
+ embedding_chunk = []
66
+ filename_chunk = []
67
+ self.embedding_queue.task_done()
68
+ elif data == 'DONE':
69
+ chunk_cnt += 1
70
+ embedding_chunk, filename_chunk = self.__remove_missing(embedding_chunk, filename_chunk)
71
+ save_chunk(embedding_chunk, filename_chunk, chunk_cnt, self.chunk_dir)
72
+ module_logger.info(
73
+ f'Completed reading embeddings. There were {fnf} images for which embeddings were not found')
74
+ self.embedding_queue.task_done()
75
+ break
76
+
77
+ def __sort_embeddings(self):
78
+ """Sort embeddings and filenames by filename"""
79
+ self.filenames = np.asarray(self.filenames)
80
+ sort_order = np.argsort(self.filenames)
81
+ self.embeddings = np.asarray(self.embeddings)[sort_order]
82
+ self.filenames = self.filenames[sort_order]
83
+
84
+ def __load_embedding(self, filename):
85
+ """Loads an embedding from the disk and puts it in the embedding queue"""
86
+ if osp.exists(filename):
87
+ embedding = np.load(filename)
88
+ data = {
89
+ 'embedding': embedding,
90
+ 'filename': filename.split('/')[-1],
91
+ }
92
+ else:
93
+ data = {
94
+ 'filename': filename.split('/')[-1],
95
+ 'embedding': None
96
+ }
97
+ self.embedding_queue.put(data)
98
+
99
+ def __read_embeddings(self):
100
+ """Reads embeddings from the disk"""
101
+ with concfut.ThreadPoolExecutor(max_workers=32) as executor:
102
+ worker = self.__load_embedding
103
+ executor.map(worker, self.embedding_files)
104
+ executor.shutdown(wait=True, cancel_futures=False)
105
+ self.embedding_queue.put('DONE')
106
+ self.embedding_queue.join()
107
+
108
+ def __read_embeddings_chunked(self):
109
+ """Reads embeddings from the disk in chunks"""
110
+ files = os.listdir(self.chunk_dir)
111
+ cnt = 0
112
+ with concfut.ProcessPoolExecutor(max_workers=32) as executor:
113
+ futures = [executor.submit(load_chunk, osp.join(self.chunk_dir, filename)) for filename in files]
114
+ for future in concfut.as_completed(futures):
115
+ result = future.result()
116
+ module_logger.info(f'Consuming {cnt}/{len(files)} chunks')
117
+ self.embeddings.extend(list(map(lambda x: x.squeeze(), result['embeddings'])))
118
+ self.filenames.extend(list(map(lambda x: '.'.join(x.split('/')[-1].split('.')[:-1]), result['filenames'])))
119
+ cnt += 1
120
+ module_logger.info('Finished reading chunks')
121
+
122
+ @staticmethod
123
+ def get_missing(x):
124
+ """Returns the indices of missing embeddings"""
125
+ indices = filter(lambda i_x: i_x[1] is None, enumerate(x))
126
+ res = np.asarray([i for i, x in indices])
127
+ return res
128
+
129
+ def __remove_missing(self, embeddings, filenames):
130
+ """Removes embeddings and filenames for which embeddings were not found"""
131
+ missing_ids = self.get_missing(embeddings)
132
+ embeddings = [ele for idx, ele in enumerate(embeddings) if idx not in missing_ids]
133
+ filenames = [ele for idx, ele in enumerate(filenames) if idx not in missing_ids]
134
+ return embeddings, filenames
135
+
136
+
137
+ def load_chunk(filename):
138
+ """Loads a chunk file containing embeddings and filenames"""
139
+ data = pickle.load(open(filename, 'rb'))
140
+ return data
141
+
142
+
143
+ def save_chunk(embeddings, filenames, count, chunk_dir, chunk_size=50000):
144
+ """Saves a chunk file containing embeddings and filenames. If the number of embeddings is less than chunk_size, it
145
+ saves all embeddings and filenames in one file. Otherwise, it splits the embeddings and filenames into chunks of
146
+ size chunk_size and saves each chunk in a separate file."""
147
+ assert len(embeddings) == len(filenames)
148
+ os.makedirs(chunk_dir, exist_ok=True)
149
+
150
+ if len(embeddings) < chunk_size:
151
+ data = {
152
+ 'embeddings': embeddings,
153
+ 'filenames': filenames,
154
+ }
155
+ pickle.dump(data, open(osp.join(chunk_dir, f'embeddings_{count}.pkl'), 'wb'))
156
+ else:
157
+ # Split into len(embeddings) / 50000 chunks
158
+ for i in range(0, math.ceil(len(embeddings)/chunk_size)):
159
+ data = {
160
+ 'embeddings': embeddings[i*chunk_size: min((i+1)*chunk_size, len(embeddings))],
161
+ 'filenames': filenames[i*chunk_size: min((i+1)*chunk_size, len(embeddings))],
162
+ }
163
+ with open(osp.join(chunk_dir, f'embeddings_{i}.pkl'), 'wb') as f:
164
+ pickle.dump(data, f)
CSD/search/faiss_search.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import faiss
3
+
4
+ module_logger = logging.getLogger(__name__)
5
+
6
+
7
+ class FaissIndex(object):
8
+ def __init__(self, index_func=faiss.IndexFlatIP, embedding_size=512*512):
9
+ self.index = index_func(embedding_size)
10
+ # Enable GPU support
11
+ # self.index_gpu = faiss.index_cpu_to_all_gpus(self.index)
12
+
13
+ def build_index(self, nodes):
14
+ self.index.add(nodes)
15
+ # Enable GPU support
16
+ # self.index_gpu.add(nodes)
17
+
18
+ def search_nns(self, embeddings, n):
19
+ # Enable GPU support
20
+ # return self.index_gpu.search(embeddings, n)
21
+ return self.index.search(embeddings, n)
CSD/utils.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Code elements borrowed from
3
+ https://github.com/clovaai/CutMix-PyTorch/blob/master/train.py
4
+ '''
5
+ import argparse
6
+ import os
7
+ import sys
8
+ from collections import defaultdict, deque
9
+ import time, datetime
10
+
11
+ import faiss
12
+ import numpy as np
13
+ import torch
14
+ import torch.distributed as dist
15
+ import torch.nn as nn
16
+
17
+ from einops import rearrange, reduce
18
+
19
+
20
+ def is_dist_avail_and_initialized():
21
+ if not dist.is_available():
22
+ return False
23
+ if not dist.is_initialized():
24
+ return False
25
+ return True
26
+
27
+
28
+ def bool_flag(s):
29
+ """
30
+ Parse boolean arguments from the command line.
31
+ """
32
+ FALSY_STRINGS = {"off", "false", "0"}
33
+ TRUTHY_STRINGS = {"on", "true", "1"}
34
+ if s.lower() in FALSY_STRINGS:
35
+ return False
36
+ elif s.lower() in TRUTHY_STRINGS:
37
+ return True
38
+ else:
39
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
40
+
41
+
42
+ def setup_for_distributed(is_master):
43
+ """
44
+ This function disables printing when not in master process
45
+ """
46
+ import builtins as __builtin__
47
+ builtin_print = __builtin__.print
48
+
49
+ def print(*args, **kwargs):
50
+ force = kwargs.pop('force', False)
51
+ if is_master or force:
52
+ builtin_print(*args, **kwargs)
53
+
54
+ __builtin__.print = print
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ args.distributed = True
59
+ # launched with torch.distributed.launch
60
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
61
+ args.rank = int(os.environ["RANK"])
62
+ args.world_size = int(os.environ['WORLD_SIZE'])
63
+ args.gpu = int(os.environ['LOCAL_RANK'])
64
+ # launched with submitit on a slurm cluster
65
+ elif 'SLURM_PROCID' in os.environ:
66
+ args.rank = int(os.environ['SLURM_PROCID'])
67
+ args.gpu = args.rank % torch.cuda.device_count()
68
+ # launched naively with `python main_dino.py`
69
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
70
+ elif torch.cuda.is_available():
71
+ print('Will run the code on one GPU.')
72
+ args.rank, args.gpu, args.world_size = 0, 0, 1
73
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
74
+ os.environ['MASTER_PORT'] = '29500'
75
+ else:
76
+ print('Does not support training without GPU.')
77
+ sys.exit(1)
78
+
79
+ dist.init_process_group(
80
+ backend="nccl",
81
+ init_method=args.dist_url,
82
+ world_size=args.world_size,
83
+ rank=args.rank,
84
+ )
85
+
86
+ torch.cuda.set_device(args.gpu)
87
+ print('| distributed init (rank {}): {}'.format(
88
+ args.rank, args.dist_url), flush=True)
89
+ dist.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ class SmoothedValue(object):
94
+ """Track a series of values and provide access to smoothed values over a
95
+ window or the global series average.
96
+ """
97
+
98
+ def __init__(self, window_size=20, fmt=None):
99
+ if fmt is None:
100
+ fmt = "{median:.6f} ({global_avg:.6f})"
101
+ self.deque = deque(maxlen=window_size)
102
+ self.total = 0.0
103
+ self.count = 0
104
+ self.fmt = fmt
105
+
106
+ def update(self, value, n=1):
107
+ self.deque.append(value)
108
+ self.count += n
109
+ self.total += value * n
110
+
111
+ def synchronize_between_processes(self):
112
+ """
113
+ Warning: does not synchronize the deque!
114
+ """
115
+ if not is_dist_avail_and_initialized():
116
+ return
117
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
118
+ dist.barrier()
119
+ dist.all_reduce(t)
120
+ t = t.tolist()
121
+ self.count = int(t[0])
122
+ self.total = t[1]
123
+
124
+ @property
125
+ def median(self):
126
+ d = torch.tensor(list(self.deque))
127
+ return d.median().item()
128
+
129
+ @property
130
+ def avg(self):
131
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
132
+ return d.mean().item()
133
+
134
+ @property
135
+ def global_avg(self):
136
+ return self.total / self.count
137
+
138
+ @property
139
+ def max(self):
140
+ return max(self.deque)
141
+
142
+ @property
143
+ def value(self):
144
+ return self.deque[-1]
145
+
146
+ def __str__(self):
147
+ return self.fmt.format(
148
+ median=self.median,
149
+ avg=self.avg,
150
+ global_avg=self.global_avg,
151
+ max=self.max,
152
+ value=self.value)
153
+
154
+
155
+ class MetricLogger(object):
156
+ def __init__(self, delimiter="\t"):
157
+ self.meters = defaultdict(SmoothedValue)
158
+ self.delimiter = delimiter
159
+
160
+ def update(self, **kwargs):
161
+ for k, v in kwargs.items():
162
+ if isinstance(v, torch.Tensor):
163
+ v = v.item()
164
+ assert isinstance(v, (float, int))
165
+ self.meters[k].update(v)
166
+
167
+ def __getattr__(self, attr):
168
+ if attr in self.meters:
169
+ return self.meters[attr]
170
+ if attr in self.__dict__:
171
+ return self.__dict__[attr]
172
+ raise AttributeError("'{}' object has no attribute '{}'".format(
173
+ type(self).__name__, attr))
174
+
175
+ def __str__(self):
176
+ loss_str = []
177
+ for name, meter in self.meters.items():
178
+ loss_str.append(
179
+ "{}: {}".format(name, str(meter))
180
+ )
181
+ return self.delimiter.join(loss_str)
182
+
183
+ def synchronize_between_processes(self):
184
+ for meter in self.meters.values():
185
+ meter.synchronize_between_processes()
186
+
187
+ def add_meter(self, name, meter):
188
+ self.meters[name] = meter
189
+
190
+ def log_every(self, iterable, print_freq, header=None):
191
+ i = 0
192
+ if not header:
193
+ header = ''
194
+ start_time = time.time()
195
+ end = time.time()
196
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
197
+ data_time = SmoothedValue(fmt='{avg:.6f}')
198
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
199
+ if torch.cuda.is_available():
200
+ log_msg = self.delimiter.join([
201
+ header,
202
+ '[{0' + space_fmt + '}/{1}]',
203
+ 'eta: {eta}',
204
+ '{meters}',
205
+ 'time: {time}',
206
+ 'data: {data}',
207
+ 'max mem: {memory:.0f}'
208
+ ])
209
+ else:
210
+ log_msg = self.delimiter.join([
211
+ header,
212
+ '[{0' + space_fmt + '}/{1}]',
213
+ 'eta: {eta}',
214
+ '{meters}',
215
+ 'time: {time}',
216
+ 'data: {data}'
217
+ ])
218
+ MB = 1024.0 * 1024.0
219
+ for obj in iterable:
220
+ data_time.update(time.time() - end)
221
+ yield obj
222
+ iter_time.update(time.time() - end)
223
+ if i % print_freq == 0 or i == len(iterable) - 1:
224
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
225
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
226
+ if torch.cuda.is_available():
227
+ print(log_msg.format(
228
+ i, len(iterable), eta=eta_string,
229
+ meters=str(self),
230
+ time=str(iter_time), data=str(data_time),
231
+ memory=torch.cuda.max_memory_allocated() / MB))
232
+ else:
233
+ print(log_msg.format(
234
+ i, len(iterable), eta=eta_string,
235
+ meters=str(self),
236
+ time=str(iter_time), data=str(data_time)))
237
+ i += 1
238
+ end = time.time()
239
+ total_time = time.time() - start_time
240
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
241
+ print('{} Total time: {} ({:.6f} s / it)'.format(
242
+ header, total_time_str, total_time / len(iterable)))
243
+
244
+
245
+ def multi_scale(samples, model, args):
246
+ v = None
247
+ for s in [1, 1 / 2 ** (1 / 2), 1 / 2]: # we use 3 different scales
248
+ if s == 1:
249
+ inp = samples.clone()
250
+ else:
251
+ inp = torch.nn.functional.interpolate(samples, scale_factor=s, mode='bilinear', align_corners=False)
252
+
253
+ if args.pt_style == 'vicregl':
254
+ feats = model(inp)[-1].clone()
255
+ elif args.pt_style == 'clip':
256
+ feats = model.module.encode_image(samples).to(torch.float32).clone()
257
+ else:
258
+ feats = model(inp).clone()
259
+ feats = torch.squeeze(feats)
260
+ feats = torch.unsqueeze(feats, 0)
261
+ if v is None:
262
+ v = feats
263
+ else:
264
+ v += feats
265
+ v /= 3
266
+ v /= v.norm()
267
+ return v
268
+
269
+
270
+ def patchify(x, size):
271
+ patches = rearrange(x, 'b c (h1 h2) (w1 w2) -> (b h1 w1) c h2 w2', h2=size, w2=size)
272
+ return patches
273
+
274
+
275
+ @torch.no_grad()
276
+ def extract_features(args, model, data_loader, use_cuda=True, multiscale=False):
277
+ metric_logger = MetricLogger(delimiter=" ")
278
+ features = None
279
+ # count = 0
280
+ for samples, index in metric_logger.log_every(data_loader, 100):
281
+ print(f'At the index {index[0]}')
282
+ samples = samples.cuda(non_blocking=True)
283
+ index = index.cuda(non_blocking=True)
284
+ if multiscale:
285
+ feats = multi_scale(samples, model, args)
286
+ else:
287
+
288
+ if args.pt_style == 'dino':
289
+ if args.layer > 1:
290
+ feats = model.module.get_intermediate_layers(samples, args.layer)[0][:, 0, :].clone()
291
+ elif args.layer == -1:
292
+
293
+ allfeats = model.module.get_intermediate_layers(samples, len(model.module.blocks))
294
+ feats = [allfeats[i - 1][:, 0, :] for i in args.multilayer]
295
+ bdim, _ = feats[0].shape
296
+ feats = torch.stack(feats, dim=1).reshape((bdim, -1)).clone()
297
+ else:
298
+ feats = model(samples).clone()
299
+
300
+ elif args.pt_style == 'moco':
301
+ feats = model.module.forward_features(samples)
302
+ feats = feats[:, 0, :].clone()
303
+ elif args.pt_style == 'vgg':
304
+ feats = model.module.features(samples).clone()
305
+ elif args.pt_style in ['clip', 'clip_wikiart']:
306
+ #
307
+ allfeats = model.module.visual.get_intermediate_layers(samples.type(model.module.dtype))
308
+ # else:
309
+ # allfeats = model.get_activations(samples) #[::-1]
310
+ allfeats.reverse()
311
+
312
+ if args.arch == 'resnet50':
313
+ # import ipdb; ipdb.set_trace()
314
+ if args.layer == -1:
315
+ raise Exception('Layer=-1 not allowed with clip resnet')
316
+ elif args.layer == 1:
317
+ feats = allfeats[0].clone()
318
+ else:
319
+ assert len(allfeats) >= args.layer, "Asking for features of layer that doesnt exist"
320
+ feats = reduce(allfeats[args.layer - 1], 'b c h w -> b c', 'mean').clone()
321
+
322
+ else:
323
+ if args.layer == -1:
324
+ feats = [allfeats[i - 1][:, 0, :] for i in args.multilayer]
325
+ bdim, _ = feats[0].shape
326
+ feats = torch.stack(feats, dim=1).reshape((bdim, -1)).clone()
327
+ else:
328
+ assert len(allfeats) >= args.layer
329
+ feats = allfeats[args.layer - 1][:, 0, :].clone()
330
+ else:
331
+ feats = model(samples).clone()
332
+ # init storage feature matrix
333
+ feats = nn.functional.normalize(feats, dim=1, p=2).to(torch.float16)
334
+ if dist.get_rank() == 0 and features is None:
335
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1], dtype=feats.dtype)
336
+ if use_cuda:
337
+ features = features.cuda(non_blocking=True)
338
+ print(f"Storing features into tensor of shape {features.shape}")
339
+ # get indexes from all processes
340
+ y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
341
+ y_l = list(y_all.unbind(0))
342
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
343
+ y_all_reduce.wait()
344
+ index_all = torch.cat(y_l)
345
+
346
+ # share features between processes
347
+ feats_all = torch.empty(
348
+ dist.get_world_size(),
349
+ feats.size(0),
350
+ feats.size(1),
351
+ dtype=feats.dtype,
352
+ device=feats.device,
353
+ )
354
+ output_l = list(feats_all.unbind(0))
355
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
356
+ output_all_reduce.wait()
357
+
358
+ # update storage feature matrix
359
+ if dist.get_rank() == 0:
360
+ if use_cuda:
361
+ features.index_copy_(0, index_all, torch.cat(output_l).cuda())
362
+ else:
363
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
364
+
365
+ return features
366
+
367
+
368
+ def extract_features_pca(args, model, pca_model, k, data_loader, use_cuda=True, multiscale=False):
369
+ metric_logger = MetricLogger(delimiter=" ")
370
+ features = None
371
+ print('In pca function')
372
+ for samples, index in metric_logger.log_every(data_loader, 100):
373
+ print(f'At the index {index[0]}')
374
+ samples = samples.cuda(non_blocking=True)
375
+ index = index.cuda(non_blocking=True)
376
+
377
+ if multiscale:
378
+ feats = multi_scale(samples, model, args)
379
+ else:
380
+
381
+ if args.pt_style in ['clip', 'clip_wikiart']:
382
+ allfeats = model.module.visual.get_intermediate_layers(samples.type(model.module.dtype))
383
+ allfeats.reverse()
384
+ if args.arch == 'resnet50':
385
+ raise Exception('code not written for this case')
386
+ else:
387
+ temp = allfeats[args.layer - 1]
388
+ temp = torch.nn.functional.normalize(temp, dim=2)
389
+ # Doing gram matrix
390
+ feats = torch.einsum('bij,bik->bjk', temp, temp)
391
+ feats = feats.div(temp.shape[1])
392
+ feats = rearrange(feats, 'b c d -> b (c d)')
393
+ if pca_model is not None:
394
+ feats = feats.cpu().detach().numpy()
395
+ feats = pca_model.apply_py(feats)
396
+ feats = torch.from_numpy(feats).cuda().clone()
397
+ else:
398
+ feats = feats.detach().clone()
399
+ del temp
400
+ del allfeats
401
+ elif args.pt_style == 'vgg':
402
+ temp = model.module.features(samples)
403
+ temp = temp.view(temp.size(0), temp.size(1), -1)
404
+ feats = torch.einsum('bji,bki->bjk', temp, temp)
405
+ feats = feats.div(temp.shape[1])
406
+ feats = rearrange(feats, 'b c d -> b (c d)')
407
+ if pca_model is not None:
408
+ feats = feats.cpu().detach().numpy()
409
+ feats = pca_model.apply_py(feats)
410
+ feats = torch.from_numpy(feats).cuda().clone()
411
+ else:
412
+ feats = feats.detach().clone()
413
+ del temp
414
+ else:
415
+ raise Exception('Code not written for these ptstyles. Come back later.')
416
+
417
+ feats = nn.functional.normalize(feats, dim=1, p=2).to(torch.float16)
418
+ # init storage feature matrix
419
+ if dist.get_rank() == 0 and features is None:
420
+ features = torch.zeros(len(data_loader.dataset), feats.shape[-1], dtype=feats.dtype)
421
+ if use_cuda:
422
+ features = features.cuda(non_blocking=True)
423
+ print(f"Storing features into tensor of shape {features.shape}")
424
+
425
+ # get indexes from all processes
426
+ y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
427
+ y_l = list(y_all.unbind(0))
428
+ y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
429
+ y_all_reduce.wait()
430
+ index_all = torch.cat(y_l)
431
+
432
+ # share features between processes
433
+ feats_all = torch.empty(
434
+ dist.get_world_size(),
435
+ feats.size(0),
436
+ feats.size(1),
437
+ dtype=feats.dtype,
438
+ device=feats.device,
439
+ )
440
+ output_l = list(feats_all.unbind(0))
441
+ output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
442
+ output_all_reduce.wait()
443
+
444
+ # update storage feature matrix
445
+ if dist.get_rank() == 0:
446
+ if use_cuda:
447
+ features.index_copy_(0, index_all, torch.cat(output_l))
448
+ else:
449
+ features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
450
+ if pca_model is None:
451
+ features = features.detach().numpy()
452
+ pca = faiss.PCAMatrix(features.shape[-1], k)
453
+ pca.train(features)
454
+ trans_features = pca.apply_py(features)
455
+ return torch.from_numpy(trans_features), pca
456
+ else:
457
+ return features, None
458
+
459
+
460
+ # saving features into numpy files
461
+ def save_embeddings_numpy(embeddings, filenames, savepath):
462
+ os.makedirs(savepath, exist_ok=True)
463
+ for c, fname in enumerate(filenames):
464
+ np_emb = np.asarray(embeddings[c, :].cpu().detach(), dtype=np.float16)
465
+ np.save(f'{savepath}/{fname}.npy', np_emb)
CSD/wikiart.csv ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,3 +1,11 @@
1
  ---
2
  license: mit
3
  ---
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  ---
4
+
5
+ ## Measuring Style Similarity in Diffusion Models
6
+
7
+ Cloned from [learn2phoenix/CSD](https://github.com/learn2phoenix/CSD?tab=readme-ov-file).
8
+
9
+ Their model (`csd-vit-l.pth`) downloaded from their [Google Drive](https://drive.google.com/file/d/1FX0xs8p-C7Ob-h5Y4cUhTeOepHzXv_46/view?usp=sharing).
10
+
11
+ The original Git Repo is in the `CSD` folder.
csd-vit-l.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40e92fad63a361b8136100cd234c42d401ef9b34ff1748234318929ebcc7e7a1
3
+ size 2438228893