Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import PIL.Image | |
import torch | |
import cv2, albumentations | |
import numpy as np | |
def save_image(img, filename):
    """Save the first image of a batch to ``filename`` as an RGB image.

    ``img`` is permuted with ``(0, 2, 3, 1)`` to channels-last, so it is
    expected to be laid out as (batch, channels, height, width); only index 0
    of the batch is written.

    Pixel values go through ``x * 127.5 + 128``, are clamped to [0, 255] and
    cast to ``torch.uint8``. For inputs in [-1, 1] this covers the full 8-bit
    range. Because the clamped values are non-negative, the truncating cast
    acts as a floor, and the extra 0.5 in the offset turns it into
    round-to-nearest.

    The array is handed to PIL in mode ``'RGB'``.
    NOTE(review): this presumably requires 3-channel input — confirm callers
    never pass grayscale tensors.
    """
    channels_last = img.permute(0, 2, 3, 1)
    scaled = channels_last.mul(127.5).add(128)
    as_bytes = scaled.clamp(0, 255).to(torch.uint8)
    first = as_bytes[0].cpu().numpy()
    PIL.Image.fromarray(first, 'RGB').save(filename)
def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of images into one grid image and save it with PIL.

    Args:
        img: Array-like batch indexed as (N, C, H, W); converted with
            ``np.asarray(..., dtype=np.float32)``.
        fname: Output path passed to ``PIL.Image.Image.save``.
        drange: ``(lo, hi)`` value range of ``img``. It is mapped linearly
            onto [0, 255], then rounded and clipped to uint8.
        grid_size: ``(gw, gh)``, the number of tile columns and rows.
            ``gw * gh`` must equal N; images fill the grid row by row.

    Raises:
        ValueError: if C is not 1 or 3, or if N != gw * gh.
    """
    lo, hi = drange
    gw, gh = grid_size
    img = np.asarray(img, dtype=np.float32)
    n, C, H, W = img.shape

    # Validate before doing any work, and with real exceptions: an ``assert``
    # is stripped under ``python -O``. In that mode an unsupported channel
    # count would silently skip the save.
    if C not in (1, 3):
        raise ValueError(f"save_image_grid supports 1 or 3 channels, got {C}")
    if n != gw * gh:
        raise ValueError(
            f"batch of {n} images does not fill a {gw}x{gh} grid "
            f"({gw * gh} cells)"
        )

    img = (img - lo) * (255 / (hi - lo))
    img = np.rint(img).clip(0, 255).astype(np.uint8)

    # (gh, gw, C, H, W) -> (gh, H, gw, W, C): interleave tile rows with pixel
    # rows and tile columns with pixel columns, so the final reshape lays the
    # tiles out as a single (gh*H, gw*W, C) image.
    img = img.reshape(gh, gw, C, H, W).transpose(0, 3, 1, 4, 2)
    img = img.reshape(gh * H, gw * W, C)

    if C == 1:
        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
    else:
        PIL.Image.fromarray(img, 'RGB').save(fname)
def resize_image(img_pytorch, curr_res):
    """Resize every image in a batch using Lanczos (cv2.INTER_LANCZOS4).

    Args:
        img_pytorch: Tensor indexed as (N, C, H, W). It is detached and copied
            to host memory for resizing. The result is moved back to
            ``img_pytorch.device``.
        curr_res: Target size. A scalar gives a square
            ``curr_res x curr_res`` output. A ``(height, width)`` tuple/list
            gives a rectangular one.

    Returns:
        Tensor of shape (N, C, height, width) on ``img_pytorch.device``. An
        empty batch yields an empty tensor with the input's dtype and device.
    """
    if isinstance(curr_res, (tuple, list)):
        height, width = curr_res
    else:
        height = width = curr_res

    n, channels = img_pytorch.shape[:2]
    if n == 0:
        # np.stack raises on an empty list, so short-circuit an empty batch.
        return img_pytorch.new_empty((0, channels, height, width))

    # detach(): Tensor.numpy() refuses tensors that require grad.
    batch = img_pytorch.detach().permute(0, 2, 3, 1).cpu().numpy()
    resized = [
        albumentations.geometric.functional.resize(
            frame, height=height, width=width, interpolation=cv2.INTER_LANCZOS4)
        for frame in batch
    ]
    stacked = torch.from_numpy(np.stack(resized, axis=0))
    return stacked.permute(0, 3, 1, 2).to(img_pytorch.device)