File size: 5,027 Bytes
cddd431
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image, ImageOps, ImageFilter
import random

def add_normalization_to_transform(
    unnormalized_transforms,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Append channel-wise normalization to every transform in a dict.

    Each value ``t`` is wrapped as
    ``transforms.Compose([t, transforms.Normalize(mean=mean, std=std)])``.
    ``Normalize`` operates on tensors, so each wrapped transform is expected to
    end with ``ToTensor()`` (as every pipeline in this module does).

    Args:
        unnormalized_transforms: Mapping of name -> transform. It is not
            modified; a new dict is returned.
        mean: Per-channel means. Defaults to the ImageNet statistics that were
            previously hard-coded here.
        std: Per-channel standard deviations. Defaults to the ImageNet
            statistics that were previously hard-coded here.

    Returns:
        A new dict with the same keys (in the same order) whose values are the
        normalizing ``Compose`` wrappers.
    """
    # Defaults are immutable tuples (no shared mutable default); convert to
    # lists so Normalize receives exactly what it did before (same repr).
    mean = list(mean)
    std = list(std)
    return {
        key: transforms.Compose([value, transforms.Normalize(mean=mean, std=std)])
        for key, value in unnormalized_transforms.items()
    }

def modify_transforms(normal_transforms, no_shift_transforms, ig_transforms):
    """Apply ``add_normalization_to_transform`` to three transform dicts.

    The dicts are processed in argument order.

    Returns:
        A 3-tuple ``(normal, no_shift, ig)`` of the normalized dicts.
    """
    return tuple(
        add_normalization_to_transform(group)
        for group in (normal_transforms, no_shift_transforms, ig_transforms)
    )
    
class Solarization:
    """Randomly solarize a PIL image.

    With probability ``p`` the image is passed through
    ``PIL.ImageOps.solarize``, which inverts pixel values from ``threshold``
    upward. Otherwise the input is returned unchanged (same object).

    Args:
        p: Probability of applying solarization. ``0`` never applies it and
            ``1`` always does.
        threshold: Solarization threshold forwarded to ``ImageOps.solarize``.
            Defaults to 128, the same value ``ImageOps.solarize`` uses by
            default, so ``Solarization(p)`` behaves exactly as before.
    """

    def __init__(self, p, threshold=128):
        self.p = p
        self.threshold = threshold

    def __call__(self, img):
        # random.random() is in [0, 1), so p=0 never fires and p=1 always does.
        if random.random() < self.p:
            return ImageOps.solarize(img, threshold=self.threshold)
        return img

    def __repr__(self):
        return f"{type(self).__name__}(p={self.p}, threshold={self.threshold})"
        
# Base pipelines without ImageNet mean/std normalization: SimCLRv2 is fed the
# plain ToTensor() output. Use add_normalization_to_transform where
# normalization is wanted.
pure_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Random resized crop + horizontal flip, then colour jitter (p=0.8),
# grayscale (p=0.2) and Gaussian blur (p=0.5).
aug_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=(21, 21), sigma=(0.1, 2.0))],
            p=0.5,
        ),
        transforms.ToTensor(),
    ]
)

def _center_crop_pipeline(*middle_steps):
    """Build Resize(256) -> CenterCrop(224) -> *middle_steps -> ToTensor()."""
    return transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            *middle_steps,
            transforms.ToTensor(),
        ]
    )


# Every ig_* pipeline shares the same fixed resize + center crop; only the
# photometric step(s) between the crop and ToTensor() differ.
ig_pure_transform = _center_crop_pipeline()

ig_transform_colorjitter = _center_crop_pipeline(
    transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.4)], p=1),
)

ig_transform_blur = _center_crop_pipeline(
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=(11, 11), sigma=(5, 5))], p=1
    ),
)

ig_transform_solarize = _center_crop_pipeline(Solarization(p=1.0))

ig_transform_grayscale = _center_crop_pipeline(transforms.RandomGrayscale(p=1))

# Jitter, grayscale and blur each applied with their own probability.
ig_transform_combine = _center_crop_pipeline(
    transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=(21, 21), sigma=(0.1, 2.0))], p=0.5
    ),
)

# "No shift" variants: resize straight to 224x224 with no random crop or flip,
# so the only randomness in the augmented pipeline is photometric.
pure_transform_no_shift = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

aug_transform_no_shift = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomApply(
            [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
)

# Name -> transform registries (un-normalized; see modify_transforms).
normal_transforms = {
    "pure": pure_transform,
    "aug": aug_transform,
}

no_shift_transforms = {
    "pure": pure_transform_no_shift,
    "aug": aug_transform_no_shift,
}

ig_transforms = {
    "pure": ig_pure_transform,
    "color_jitter": ig_transform_colorjitter,
    "blur": ig_transform_blur,
    "grayscale": ig_transform_grayscale,
    "solarize": ig_transform_solarize,
    "combine": ig_transform_combine,
}