File size: 3,553 Bytes
ff715ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
826f0e1
ff715ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch.nn.functional as F

def numpy2tensor(img, device="cuda"):
    """Convert an H x W x C image array into a 1 x C x H x W float tensor.

    Values are rescaled with ``x / 255 * 2 - 1``, i.e. 0 maps to -1 and
    255 maps to 1.

    Args:
        img: NumPy array indexed as (height, width, channel).
        device: Device the tensor is moved to. Defaults to ``"cuda"``,
            which matches the previous hard-coded behaviour; pass
            ``"cpu"`` to run without a GPU.

    Returns:
        Float tensor of shape (1, C, H, W) on ``device``.
    """
    # Copy first: torch.from_numpy cannot wrap read-only or negatively
    # strided arrays, and the copy keeps the caller's buffer untouched.
    x0 = torch.from_numpy(img.copy()).float().to(device) / 255.0 * 2.0 - 1.
    # Add the batch axis, then go from channels-last to channels-first
    # (b h w c -> b c h w).
    return x0.unsqueeze(0).permute(0, 3, 1, 2)

def pil2tensor(img):
    """Turn an image object into a tensor by way of ``numpy2tensor``.

    ``img`` is first materialised as a NumPy array with ``np.array`` and
    the result of ``numpy2tensor`` on that array is returned unchanged.
    """
    as_array = np.array(img)
    return numpy2tensor(as_array)

def tensor2numpy(img):
    """Convert a batch tensor into a channels-last ``uint8`` array.

    The input is mapped with ``x / 2 + 0.5``, clamped to [0, 1], moved to
    the CPU, permuted from (B, C, H, W) to (B, H, W, C), scaled by 255,
    rounded and cast to ``uint8``.
    """
    unit_range = torch.clamp(img / 2 + 0.5, min=0, max=1)
    channels_last = unit_range.detach().cpu().permute(0, 2, 3, 1)
    scaled = channels_last.numpy() * 255
    return np.round(scaled).astype("uint8")

def tensor2pil(img):
    """Build a PIL image from the first entry of a batch tensor.

    The tensor is converted with ``tensor2numpy``; only index 0 of the
    resulting batch is handed to ``Image.fromarray``.
    """
    batch = tensor2numpy(img)
    first_frame = batch[0]
    return Image.fromarray(first_frame)

def cv2sod(img):
    """Prepare an H x W x 3 image for the saliency model.

    The image is converted to float32, a fixed per-channel offset is
    subtracted, the layout becomes (1, 3, H, W) and the result is
    bilinearly downsampled by a factor of 0.5.
    """
    # Per-channel offsets; presumably dataset means in OpenCV's BGR
    # order -- TODO confirm against the saliency model's training setup.
    channel_offsets = np.array((104.00699, 116.66877, 122.67892))

    pixels = np.array(img, dtype=np.float32)
    pixels -= channel_offsets  # in place, so the buffer stays float32
    channels_first = torch.Tensor(pixels.transpose((2, 0, 1)))
    batch = channels_first.unsqueeze(0)
    return F.interpolate(batch, scale_factor=0.5, mode='bilinear')

def get_frame_count(video_path: str):
    """Return the frame count OpenCV reports for a video file.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        ``CAP_PROP_FRAME_COUNT`` converted to ``int``.
    """
    video = cv2.VideoCapture(video_path)
    try:
        return int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the capture handle even if the query or the int()
        # conversion raises, so the file/decoder is never leaked.
        video.release()

def resize_image(input_image, resolution):
    """Resize an image so its longer side is about ``resolution`` pixels.

    Both sides are scaled by the same factor and then rounded
    independently to the nearest multiple of 64, so the aspect ratio is
    only approximately preserved.

    Args:
        input_image: Array whose first two axes are (height, width); a
            trailing channel axis is optional.
        resolution: Target length of the longer side before rounding.

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    # Only the first two axes matter, so 2-D (grayscale) arrays work too.
    height, width = input_image.shape[:2]
    scale = float(resolution) / max(height, width)

    # Round to a multiple of 64, but never to 0: cv2.resize rejects an
    # empty target size (happens with extreme aspect ratios or a very
    # small resolution).
    new_h = max(64, int(np.round(height * scale / 64.0)) * 64)
    new_w = max(64, int(np.round(width * scale / 64.0)) * 64)

    # Lanczos when enlarging, area averaging when shrinking.
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (new_w, new_h), interpolation=interpolation)

def visualize(img_arr, dpi):
    """Show a 3-D tensor as an image in a 10 x 10 inch matplotlib figure.

    The tensor is moved to the CPU, its axes are reordered with
    ``transpose(1, 2, 0)``, and values are mapped with
    ``(x + 1) * 127.5`` before the cast to ``uint8``.
    """
    plt.figure(figsize=(10, 10), dpi=dpi)
    channels_last = img_arr.detach().cpu().numpy().transpose(1, 2, 0)
    as_bytes = ((channels_last + 1.0) * 127.5).astype(np.uint8)
    plt.imshow(as_bytes)
    plt.axis('off')
    plt.show()


def calc_mean_std(feat, eps=1e-5, chunk=1):
    """Compute per-channel mean and standard deviation of a 4-D tensor.

    Args:
        feat: Tensor of shape (N, C, H, W).
        eps: Added to the variance before the square root so the result
            is never exactly zero.
        chunk: Number of equal batch groups whose statistics are pooled.
            With ``chunk > 1`` the batch is split into ``chunk`` pieces
            along dim 0 and entry ``i`` of every piece shares one
            mean/std.

    Returns:
        ``(mean, std)``, each of shape (N, C, 1, 1).

    Raises:
        ValueError: If ``chunk > 1`` and N is not divisible by ``chunk``.
    """
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    if chunk > 1:
        if N % chunk != 0:
            raise ValueError(
                f"batch size {N} is not divisible by chunk={chunk}")
        # Put the pieces side by side along the width axis so a single
        # reduction pools their statistics.
        feat = torch.cat(feat.chunk(chunk), dim=3)
    groups = N // chunk
    # reshape (not view) so non-contiguous inputs are accepted as well.
    flat = feat.reshape(groups, C, -1)
    # The statistics have `groups` rows, not N; viewing them as N rows
    # made every chunk > 1 call fail.
    feat_std = (flat.var(dim=2) + eps).sqrt().view(groups, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(groups, C, 1, 1)
    # Tile back to the full batch size.
    return feat_mean.repeat(chunk, 1, 1, 1), feat_std.repeat(chunk, 1, 1, 1)


def adaptive_instance_normalization(content_feat, style_feat, chunk=1):
    """Re-normalise content features to the style's channel statistics.

    Each channel of ``content_feat`` is standardised with its own
    mean/std and then scaled and shifted by the mean/std of
    ``style_feat``.

    Args:
        content_feat: Tensor of shape (N, C, H, W).
        style_feat: Tensor with the same N and C as ``content_feat``.
        chunk: Forwarded to ``calc_mean_std`` for the style statistics
            only; content statistics are always per sample.

    Returns:
        Tensor with the shape of ``content_feat``.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    # chunk must go by keyword: positionally it would land in `eps`.
    style_mean, style_std = calc_mean_std(style_feat, chunk=chunk)
    content_mean, content_std = calc_mean_std(content_feat)

    # The (N, C, 1, 1) statistics broadcast over the spatial axes.
    normalized_feat = (content_feat - content_mean) / content_std
    return normalized_feat * style_std + style_mean


class Dilate():
    """Dilate a mask with a square all-ones kernel.

    Each output pixel is the sum of its ``kernel_size`` x ``kernel_size``
    neighbourhood, clamped to [0, 1]. Channels are filtered
    independently of each other.

    Args:
        kernel_size: Side length of the square kernel. For an even value
            the output is one pixel smaller than the input in each
            spatial dimension, because the padding is
            ``(kernel_size - 1) // 2`` on every side.
        channels: Number of channels of the tensors passed to
            ``__call__``.
        device: Device the kernel is stored on.
    """

    def __init__(self, kernel_size=7, channels=1, device='cpu'):
        self.kernel_size = kernel_size
        self.channels = channels
        # Padding applied on each side before the convolution.
        self.mean = (self.kernel_size - 1) // 2
        # One all-ones kernel per channel, shape (channels, 1, k, k).
        # The attribute keeps its historical name although the kernel is
        # a box filter, not a Gaussian.
        box_kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size)
        box_kernel = box_kernel.repeat(self.channels, 1, 1, 1)
        self.gaussian_filter = box_kernel.to(device)

    def __call__(self, x):
        """Return the dilated version of ``x`` (shape (N, channels, H, W))."""
        x = F.pad(x, (self.mean, self.mean, self.mean, self.mean), "replicate")
        # groups=channels makes the (channels, 1, k, k) kernel a
        # per-channel filter; without it conv2d rejects channels > 1.
        dilated = F.conv2d(x, self.gaussian_filter, bias=None,
                           groups=self.channels)
        return torch.clamp(dilated, 0, 1)

@torch.no_grad()
def get_saliency(imgs, sod_model, dilate):
    """Compute an inverted, dilated saliency mask for a list of images.

    Args:
        imgs: Iterable of images accepted by ``cv2sod``.
        sod_model: Callable returning three values; the last element of
            the third one is used as the saliency logits.
        dilate: Callable applied to the sigmoid saliency map (for
            example a ``Dilate`` instance).

    Returns:
        ``1 - dilate(saliency)``.
    """
    imgs_sod = torch.cat([cv2sod(img) for img in imgs], dim=0).cuda()
    _, _, up_sal_f = sod_model(imgs_sod)
    probs = torch.sigmoid(up_sal_f[-1])
    # Force an explicit (B, 1, H, W) layout. A bare squeeze() also drops
    # the batch axis when only one image is passed, which produced a
    # (H, 1, W) tensor. Presumably the model emits one single-channel
    # map per image -- TODO confirm against the SOD model.
    probs = probs.reshape(-1, 1, *probs.shape[-2:])
    saliency = 1 - dilate(probs)
    # Drop the model outputs and hand cached GPU blocks back to the
    # allocator before returning.
    del up_sal_f
    torch.cuda.empty_cache()
    return saliency