PyTorch translation and implementation of DeePSim generators

Generating Images with Perceptual Similarity Metrics based on Deep Networks. Alexey Dosovitskiy, Thomas Brox (2016).

  • The network architecture is translated from the original Caffe implementation using the pytorch-caffe repo.
  • Network definitions in pure PyTorch are included in GAN_utils.py.
  • Weights are converted from the pre-trained Caffe weights on Alexey Dosovitskiy's homepage and saved as PyTorch state dicts.
  • This repo contains pre-trained state dicts for 9 generative models and one classification model (CaffeNet). Each generator is trained to invert the representation of one CaffeNet layer: norm1, norm2, conv3, conv4, pool5, fc6 (two variants, fc6 and fc6_eucl), fc7 and fc8. All of these models are relatively simple, consisting only of linear, convolution and deconvolution (transposed convolution) layers.
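To make "only linear, convolution and deconvolution layers" concrete, here is a minimal sketch of a generator with that structure: a 4096-d code (the width of CaffeNet's fc6) passes through fully connected layers, is reshaped to a small feature map, and is upsampled by transposed convolutions to an image. The class name ToyUpconvGenerator, the layer widths, the 256x4x4 reshape, the 256x256 output size and the LeakyReLU slope are all assumptions for illustration. This is not the definition in GAN_utils.py and it cannot load the provided state dicts.

```python
import torch
import torch.nn as nn


class ToyUpconvGenerator(nn.Module):
    """Hypothetical sketch, not the GAN_utils.py definition: maps a code vector to a
    3x256x256 image using only Linear, Conv2d and ConvTranspose2d layers."""

    def __init__(self, code_dim: int = 4096, hidden: int = 4096, slope: float = 0.3):
        super().__init__()

        def act() -> nn.Module:
            # The LeakyReLU slope is an assumption
            return nn.LeakyReLU(slope, inplace=True)

        # Fully connected stem ending in a 4096-d vector that is reshaped to 256x4x4
        self.fc = nn.Sequential(
            nn.Linear(code_dim, hidden), act(),
            nn.Linear(hidden, hidden), act(),
            nn.Linear(hidden, 256 * 4 * 4), act(),
        )
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles height and width
        self.up = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, stride=2, padding=1), act(),  # 8x8
            nn.Conv2d(256, 512, 3, padding=1), act(),
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1), act(),  # 16x16
            nn.Conv2d(256, 256, 3, padding=1), act(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), act(),  # 32x32
            nn.Conv2d(128, 128, 3, padding=1), act(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1), act(),   # 64x64
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), act(),    # 128x128
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),            # 256x256, no activation
        )

    def forward(self, code: torch.Tensor) -> torch.Tensor:
        h = self.fc(code).view(-1, 256, 4, 4)  # batch x 4096 -> batch x 256 x 4 x 4
        return self.up(h)


G = ToyUpconvGenerator()  # randomly initialized; cannot load the provided weights
with torch.no_grad():
    img = G(torch.randn(1, 4096))
print(img.shape)  # torch.Size([1, 3, 256, 256])
```

The repeated kernel-4, stride-2, padding-1 transposed convolutions are a common way to double resolution exactly at each step, since the output size is (in - 1) * 2 - 2 + 4 = 2 * in.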

Example usage

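```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from GAN_utils import Caffenet, upconvGAN

CNN = Caffenet(pretrained=True)
layer = "conv3"
invert_layer_id = 9  # index in CNN.net whose output the conv3 generator inverts
G = upconvGAN(name=layer, pretrained=True)
# Assumes a CUDA GPU: the forward pass below runs there, so both models are moved too
CNN.net.cuda().eval()
G.cuda().eval()

img = Image.open(...).convert("RGB")  # replace ... with the path to an image
img = img.resize((227, 227))  # CaffeNet input size
img = np.array(img)  # H x W x 3, uint8, RGB

# Caffe-style preprocessing: subtract the ImageNet mean (RGB order), then reorder to BGR
RGB_mean = torch.tensor([123.0, 117.0, 104.0]).reshape(1, 3, 1, 1)
img_preproc = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()  # HWC -> 1x3xHxW
img_preproc = (img_preproc - RGB_mean)[:, [2, 1, 0], :, :]  # RGB -> BGR

with torch.no_grad():
    # Run CaffeNet up to and including module invert_layer_id, then invert that representation
    out = CNN.net[:invert_layer_id + 1](img_preproc.cuda())
    imgtsr_recon_pp = G.visualize(out).cpu()

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(imgtsr_recon_pp[0].permute(1, 2, 0).numpy())
axes[0].set_title(f"Reconstructed-{layer}")
axes[1].imshow(img)
axes[1].set_title("Original")
for ax in axes:
    ax.axis("off")
plt.show()
```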
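The preprocessing in the example follows the Caffe convention: subtract the per-channel ImageNet mean given in RGB order, then reorder the channels to BGR. A minimal NumPy sketch of that step and its inverse follows, which can be handy for checking a preprocessed input by eye. The function names preprocess and deprocess are hypothetical, and whether G.visualize performs exactly this inverse internally is not stated in this card.

```python
import numpy as np

# Per-channel ImageNet mean in RGB order, as in the example above
RGB_MEAN = np.array([123.0, 117.0, 104.0], dtype=np.float32).reshape(1, 3, 1, 1)


def preprocess(img_rgb: np.ndarray) -> np.ndarray:
    """H x W x 3 uint8 RGB image -> 1 x 3 x H x W float32, mean-subtracted, BGR order."""
    x = img_rgb.astype(np.float32).transpose(2, 0, 1)[None]  # HWC -> 1xCxHxW
    x = x - RGB_MEAN                                         # subtract the mean while still RGB
    return x[:, ::-1, :, :].copy()                           # RGB -> BGR, same as [:, [2, 1, 0]]


def deprocess(x_bgr: np.ndarray) -> np.ndarray:
    """Inverse of preprocess: 1 x 3 x H x W BGR, mean-subtracted -> H x W x 3 uint8 RGB."""
    x = x_bgr[:, ::-1, :, :] + RGB_MEAN                      # BGR -> RGB, add the mean back
    x = np.clip(x[0].transpose(1, 2, 0), 0, 255)             # 1xCxHxW -> HxWxC, clamp to valid range
    return np.rint(x).astype(np.uint8)
```

Subtracting the mean before the channel swap matters: after the swap, channel 0 holds blue, so it must have had 104 (not 123) removed.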


Figure: reconstructions of ImageNet validation set images (originals in the 1st column) from their representations at the corresponding CaffeNet layers, using the norm1 to fc8 generators.

To our understanding, each generative model corresponds to the following layer index in CNN.net, i.e. the value to use as invert_layer_id in the example above:

```python
invers_layer_map = {
    "norm1": 3,
    "norm2": 7,
    "conv3": 9,
    "conv4": 11,
    "pool5": 14,
    "fc6": 17,
    "fc6_eucl": 17,
    "fc7": 19,
    "fc8": 20,
}
```
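In the example, the map entry is used as invert_layer_id and CNN.net is cut with [:invert_layer_id + 1]; slices exclude their end point, so the +1 keeps the module at that index. A small hypothetical helper, truncate_for, that performs this lookup and slice is sketched below. It works on any sliceable sequence, including nn.Sequential, whose slicing returns a new nn.Sequential.

```python
from typing import Mapping, Sequence


def truncate_for(net: Sequence, layer: str, layer_map: Mapping[str, int]):
    """Hypothetical helper: return net[0..idx] inclusive, where idx = layer_map[layer]."""
    if layer not in layer_map:
        raise KeyError(f"unknown generator {layer!r}; expected one of {sorted(layer_map)}")
    idx = layer_map[layer]
    return net[: idx + 1]  # slice end is exclusive, so +1 keeps module idx
```

With the models from the example, truncate_for(CNN.net, "conv3", invers_layer_map) would give the same sub-network as CNN.net[:9 + 1]. Because fc6 and fc6_eucl share index 17, both generators receive the same truncated network's output.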
