Spaces:
Runtime error
Runtime error
# Date: 2023-03-14
# Creator: zejunyang
# Function: Edge attention layer (边缘注意力层).
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from NTED.base_function import Blur | |
class ResBlock(nn.Module):
    """Residual block: rescale the input, then add a conv-ReLU-conv refinement.

    The input first goes through ``self.scale``, which changes the channel
    count (and, depending on ``scale``, the spatial size).  The result is then
    refined by ``self.block`` and the sum of both paths is passed through a
    final ReLU.  No normalization layers are used.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        scale: one of
            ``'down'`` -- 3x3 conv with stride 2 and padding 1
            (output size ceil(H/2) x ceil(W/2));
            ``'up'``   -- bilinear 2x upsampling followed by a 1x1 conv;
            ``'same'`` -- 3x3 conv with stride 1 and padding 1 (size kept).
    """

    def __init__(self, in_nc, out_nc, scale='down'):
        super().__init__()
        use_bias = True
        assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"

        def conv3x3(cin, cout, stride=1):
            # Every 3x3 conv in this block uses padding=1 and a bias term.
            return nn.Conv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=use_bias)

        # Shortcut / rescaling path. It is created before the refinement
        # convs so that parameter initialization order stays the same.
        if scale == 'same':
            self.scale = conv3x3(in_nc, out_nc)
        elif scale == 'up':
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            project = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
            self.scale = nn.Sequential(upsample, project)
        elif scale == 'down':
            self.scale = conv3x3(in_nc, out_nc, stride=2)

        # Refinement path applied on top of the rescaled features.
        first_conv = conv3x3(out_nc, out_nc)
        second_conv = conv3x3(out_nc, out_nc)
        self.block = nn.Sequential(first_conv, nn.ReLU(inplace=True), second_conv)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(scale(x) + block(scale(x)))``."""
        shortcut = self.scale(x)
        refined = self.block(shortcut)
        return self.relu(shortcut + refined)
class Edge_Attn(nn.Module):
    """Edge attention layer: gates a feature map with an edge-derived attention map.

    For an input ``x``:

    1. a per-channel gradient-magnitude map of ``x`` is computed (see
       :meth:`gradient`) with no gradient flowing back through it;
    2. that map is multiplied with ``x``;
    3. the product goes through ``self.res_block`` and a sigmoid to give ``attn``;
    4. the layer returns ``x * attn + x``.

    Args:
        in_channels: channel count of the input; the residual block maps
            ``in_channels -> in_channels``.
    """

    def __init__(self, in_channels=3):
        super(Edge_Attn, self).__init__()
        self.in_channels = in_channels
        blur_kernel = [1, 3, 3, 3, 1]
        # 5-tap blur with 2 pixels of padding on each side. This presumably keeps
        # H x W unchanged, which forward() relies on for `x * feature_edge`.
        # TODO confirm against Blur.
        self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)
        self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
        self.sigmoid = nn.Sigmoid()

    def gradient(self, x, stride=3):
        """Return a blurred gradient-magnitude map of ``x``.

        Horizontal and vertical differences are taken between the pixels
        ``stride`` positions to either side, using replicate padding at the
        borders. Each difference is halved, the two are combined as
        ``sqrt(dx**2 + dy**2)``, and the result is smoothed with ``self.blur``.

        Args:
            x: 4-D tensor indexed as (N, C, H, W).
            stride: pixel offset of the differences (default 3, the value
                previously hard-coded).

        Note:
            The backward of the square root is not finite where the magnitude is
            exactly 0. ``forward`` never backpropagates through this map.
        """
        h_x = x.size()[2]
        w_x = x.size()[3]
        # Shifted copies of x: value `stride` pixels to the right/left/above/below.
        r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
        l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
        t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
        b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]
        xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
        xgrad = self.blur(xgrad)
        return xgrad

    def forward(self, x):
        """Apply edge attention to ``x`` (N, C, H, W); the output has the same shape."""
        n, c, h, w = x.shape
        # The edge map is used as a constant (no gradient through it), so do not
        # build an autograd graph for it at all.
        with torch.no_grad():
            # Fold channels into the batch so every channel is processed as its
            # own single-channel image in ONE call. This is equivalent to calling
            # gradient() on each x[b:b+1, c:c+1] slice, assuming self.blur treats
            # batch entries independently.
            feature_edge = self.gradient(x.reshape(n * c, 1, h, w))
            feature_edge = feature_edge.reshape(n, c, *feature_edge.shape[2:])
        feature_edge = x * feature_edge
        attn = self.res_block(feature_edge)
        attn = self.sigmoid(attn)
        # Residual gating: features are amplified where attention is high.
        out = x * attn + x
        return out
if __name__ == '__main__':
    # Quick visual check of Edge_Attn.gradient.
    # Load an image (path from the first CLI argument, or the default below),
    # convert it to grayscale, compute its edge map and save it as tmp.png.
    import sys

    from PIL import Image
    import numpy as np
    import cv2

    default_path = '/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png'
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    edg_atten = Edge_Attn()
    # convert('RGB') makes palette, RGBA and single-channel files all reach
    # cvtColor(COLOR_RGB2GRAY) with the three channels it requires.
    # The `with` block closes the underlying file handle.
    with Image.open(image_path) as im:
        rgb = im.convert('RGB')
    npim = np.array(rgb, dtype=np.float32)
    npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)
    tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)  # (1, 1, H, W)
    with torch.no_grad():
        edge = edg_atten.gradient(tim)
    npgrad = edge.squeeze(0).squeeze(0).clamp(0, 255).numpy()
    Image.fromarray(npgrad.astype('uint8')).save('tmp.png')