# Spaces: Running on Zero
import os | |
import imageio | |
import numpy as np | |
import glob | |
import sys | |
from typing import Any | |
sys.path.insert(1, '.') | |
import argparse | |
from pytorch_lightning import seed_everything | |
from PIL import Image | |
import torch | |
from operators import GaussialBlurOperator | |
from utils import get_rank | |
from torchvision.ops import masks_to_boxes | |
from matfusion import MateralDiffusion | |
from loguru import logger | |
# Upper bound on how many tiles are handed to the diffusion model in one
# call (the value 4 was chosen for an A10 GPU).
__MAX_BATCH__: int = 4
def init_model(ckpt_path, ddim, gpu_id):
    """Build a ``MateralDiffusion`` model from a training-run directory.

    Args:
        ckpt_path: directory holding ``configs/`` (with a file whose name
            contains ``project.yaml``) and ``checkpoints/last.ckpt``.
        ddim: forwarded to ``MateralDiffusion`` as ``use_ddim``.
        gpu_id: forwarded to ``MateralDiffusion`` as ``device``.

    Returns:
        The constructed ``MateralDiffusion`` instance (fp16,
        ``t_range=[0.001, 0.02]``).

    Raises:
        FileNotFoundError: if ``configs/`` contains no ``*project.yaml`` file.
    """
    config_dir = os.path.join(ckpt_path, "configs")
    # os.listdir() order is arbitrary; sort so the same file is picked on
    # every run when several names match.
    matches = sorted(name for name in os.listdir(config_dir) if "project.yaml" in name)
    if not matches:
        raise FileNotFoundError(f"no config containing 'project.yaml' found in {config_dir}")
    return MateralDiffusion(device=gpu_id, fp16=True,
                            config=os.path.join(config_dir, matches[0]),
                            ckpt=os.path.join(ckpt_path, "checkpoints", "last.ckpt"),
                            vram_O=False, t_range=[0.001, 0.02], opt=None, use_ddim=ddim)
def images_spliter(image, seg_h, seg_w, padding_pixel, padding_val, overlaps=1):
    """Cut an (H, W, C) image into a grid of (optionally overlapping) crops.

    The image is first truncated at the bottom/right so that H and W are
    multiples of ``seg_h * overlaps`` and ``seg_w * overlaps``. It is then
    padded by ``padding_pixel`` on every side with ``padding_val`` and cut
    into windows of ``(H // seg_h + 2 * padding_pixel, W // seg_w + 2 * padding_pixel)``.
    With ``overlaps > 1`` the stride is ``crop / overlaps``, so neighbouring
    windows overlap and more than ``seg_h * seg_w`` crops are produced.

    Args:
        image: tensor indexed as (H, W, C).
        seg_h: number of non-overlapping tiles along the height.
        seg_w: number of non-overlapping tiles along the width.
        padding_pixel: border width added around every crop.
        padding_val: fill value for that border.
        overlaps: integer overlap factor (1 = no overlap).

    Returns:
        ``(crops, positions, n_h, n_w)``: ``crops`` stacked as
        (N, h_crop + 2p, w_crop + 2p, C); ``positions`` an (N, 2) float32
        tensor with each crop's top-left (row, col) in un-padded image
        coordinates (negative when it starts inside the padding); ``n_h`` and
        ``n_w`` the number of crops along height / width (N = n_h * n_w).

    Raises:
        ValueError: if the image is smaller than ``seg_h * overlaps`` by
            ``seg_w * overlaps`` pixels, which leaves nothing to crop.
    """
    height, width, channels = image.shape
    # Drop the bottom/right remainder so each side splits evenly into
    # seg * overlaps pieces; this keeps the window stride an exact integer.
    height -= height % (seg_h * overlaps)
    width -= width % (seg_w * overlaps)
    h_crop = height // seg_h
    w_crop = width // seg_w
    if h_crop == 0 or w_crop == 0:
        raise ValueError(
            f"image of shape {tuple(image.shape)} is too small for a "
            f"{seg_h}x{seg_w} split with overlap {overlaps}")

    padded = torch.zeros(height + padding_pixel * 2, width + padding_pixel * 2, channels,
                         device=image.device) + padding_val
    padded[padding_pixel:height + padding_pixel, padding_pixel:width + padding_pixel, :] = image[:height, :width]

    # Overlapped sampling: number of window positions and stride per axis.
    n_h = int(np.round((height - h_crop) / h_crop * overlaps)) + 1
    n_w = int(np.round((width - w_crop) / w_crop * overlaps)) + 1
    h_step = int(np.round(h_crop / overlaps))
    w_step = int(np.round(w_crop / overlaps))

    crops = []
    positions = []
    win_h = h_crop + padding_pixel * 2
    win_w = w_crop + padding_pixel * 2
    for ind_i in range(n_h):
        top = ind_i * h_step
        for ind_j in range(n_w):
            left = ind_j * w_step
            crops.append(padded[top:top + win_h, left:left + win_w, :])
            positions.append(torch.tensor([top - padding_pixel, left - padding_pixel], dtype=torch.float32))
    return torch.stack(crops, dim=0), torch.stack(positions, dim=0), n_h, n_w
class InferenceModel():
    """Tiled inference wrapper around the ``MateralDiffusion`` model.

    ``generation`` crops the input to the foreground mask's bounding box,
    splits the crop into (optionally overlapping) tiles, runs the model on
    at most ``__MAX_BATCH__`` tiles per call and stitches the tile outputs
    back into full-resolution uint8 images.
    """

    def __init__(self, ckpt_path, use_ddim, gpu_id=0):
        self.model = init_model(ckpt_path, use_ddim, gpu_id=gpu_id)
        self.gpu_id = gpu_id
        # Tiling configuration; ``generation`` overwrites these per call. The
        # defaults let ``parse_item`` run before ``generation`` was called.
        self.split_hw = [1, 1]
        self.split_overlap = 1
        self.split_hw_overlapped = [1, 1]
        self.padding = 0  # border added around each tile when splitting
        self.padding_crop = 0  # border cut from each tile when stitching
        # Streaming buffers consumed by ``write_batch_img``.
        self.results_list = None
        self.position_indexes = None
        self.results_output_list = []
        self.image_sizes_list = []

    def _reset_stream(self):
        # Drop tiles buffered by an earlier (possibly aborted) run so they
        # cannot be stitched together with tiles of the current image.
        self.results_list = None
        self.position_indexes = None
        self.image_sizes_list = []

    @staticmethod
    def _to_uint8(tiles):
        """Scale a [0, 1] float tensor to a 0-255 uint8 numpy array.

        Values are clipped first: NumPy's float -> uint8 cast does not
        saturate out-of-range values, so e.g. 1.2 * 255 would not become 255.
        """
        scaled = tiles.detach().cpu().numpy() * 255
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def parse_item(self, img_ori, mask_img_ori, guid_images):
        """Crop one image to its mask's bounding box and split it into tiles.

        Sets ``self.ori_hw``, ``self.min_uv``, ``self.max_uv`` and
        ``self.split_hw_overlapped``, which the stitching methods read later.

        Args:
            img_ori: image tensor indexed as (H, W, C).
            mask_img_ori: mask tensor indexed as (H, W, channels); values
                above 0.5 count as foreground.
            guid_images: guidance image aligned with ``img_ori``, or None to
                use all-zero guidance tiles.

        Returns:
            ``(guid_tiles, img_tiles, mask_tiles, image_size, position_indexes)``;
            ``mask_tiles`` keeps a single channel and ``image_size`` is the
            shape of the cropped region as a list.
        """
        # Paint the background white on a copy so the caller's image is left
        # untouched (white background, same as the training data).
        img_ori = img_ori.clone() if torch.is_tensor(img_ori) else img_ori.copy()
        # NOTE(review): the background test and the tile mask use channel 0
        # while the bounding box uses channel -1; identical for single-channel
        # masks -- confirm this is intended for multi-channel ones.
        img_ori[~(mask_img_ori[..., 0] > 0.5)] = 1
        use_true_mask = (self.split_hw[0] * self.split_hw[1]) <= 1
        self.ori_hw = list(img_ori.shape)

        # masks_to_boxes yields (x1, y1, x2, y2); reorder to (row, col) and
        # add 1 so max_uv works as an exclusive slice bound.
        boxes = masks_to_boxes(mask_img_ori[None, ..., -1] > 0.5).long()
        self.min_uv, self.max_uv = boxes[0, ..., [1, 0]], boxes[0, ..., [3, 2]] + 1
        if not use_true_mask:
            # Trim the box so each side divides evenly into split * overlap pieces.
            for axis in (0, 1):
                extent = self.max_uv[axis] - self.min_uv[axis]
                self.max_uv[axis] = self.max_uv[axis] - extent % (self.split_hw[axis] * self.split_overlap)
        rows = slice(self.min_uv[0], self.max_uv[0])
        cols = slice(self.min_uv[1], self.max_uv[1])
        mask_img = mask_img_ori[rows, cols]
        img = img_ori[rows, cols]
        image_size = list(img.shape)

        if not use_true_mask:
            # With several tiles the whole crop is treated as foreground.
            mask_img = torch.ones_like(mask_img)
        # The mask is split as 3 channels and cut back to 1 on return; its
        # padding value is 1 when tiling and 0 otherwise.
        mask_img = images_spliter(mask_img[..., [0, 0, 0]], self.split_hw[0], self.split_hw[1],
                                  self.padding, not use_true_mask, self.split_overlap)[0]
        img, position_indexes, seg_h, seg_w = images_spliter(
            img, self.split_hw[0], self.split_hw[1], self.padding, 1, self.split_overlap)
        self.split_hw_overlapped = [seg_h, seg_w]
        logger.info(f"Spliting Size: {image_size}, splits: {self.split_hw}, Overlapped: {self.split_hw_overlapped}")

        if guid_images is None:
            guid_images = torch.zeros_like(img)
        else:
            guid_images = images_spliter(guid_images[rows, cols], self.split_hw[0], self.split_hw[1],
                                         self.padding, 1, self.split_overlap)[0]
        return guid_images, img, mask_img[..., :1], image_size, position_indexes

    def prepare_batch(self, guid_img, img_ori, mask_img_ori, batch_size):
        """Tile ``batch_size`` copies of the input and move them to ``self.gpu_id``.

        Returns:
            ``(input_img, cond_img, mask_img, image_size, position_indexes)``:
            guidance, image and mask tiles and tile positions, each
            concatenated along dim 0, plus one ``image_size`` entry per tile.
        """
        guid_tiles, img_tiles, mask_tiles, positions = [], [], [], []
        image_size = []
        for _ in range(batch_size):
            guid, img, mask, size, pos = self.parse_item(img_ori, mask_img_ori, guid_img)
            guid_tiles.append(guid)
            img_tiles.append(img)
            mask_tiles.append(mask)
            positions.append(pos)
            image_size += [size] * guid.shape[0]

        def _cat(parts):
            return torch.cat(parts, dim=0).to(self.gpu_id)

        return _cat(guid_tiles), _cat(img_tiles), _cat(mask_tiles), image_size, _cat(positions)

    def assemble_results(self, img_out, img_hw=None, position_index=None, default_val=1):
        """Stitch the tiles of one image and paste them into a full-size canvas.

        Args:
            img_out: sequence of tiles, each indexed as (h, w, channels) with
                0-255 values; single-channel tiles broadcast to 3 channels.
            img_hw: (height, width, ...) of the cropped region.
            position_index: (N, 2) integer array with each tile's top-left
                (row, col) inside that region; may be negative.
            default_val: fill, as a fraction of 255, for pixels outside
                ``[min_uv, max_uv)``.

        Returns:
            float array of shape (ori_h, ori_w, 3). Overlapping tiles are
            averaged and region pixels covered by no tile are 255.
        """
        region_h, region_w = img_hw[0], img_hw[1]
        acc = np.zeros((region_h, region_w, 3))
        # Tile count per pixel. It starts at 0 (not an epsilon) so a single
        # tile reproduces its values exactly; dividing by 1 + eps made the
        # later uint8 truncation drop every nonzero value by one.
        weight = np.zeros((region_h, region_w, 3))
        pc = self.padding_crop
        for i, (row, col) in enumerate(position_index):
            tile = img_out[i]
            tile_h, tile_w = tile.shape[:2]
            tile = tile[pc:tile_h - pc, pc:tile_w - pc]
            row, col = row + pc, col + pc
            tile_h, tile_w = tile.shape[:2]
            # Intersection of the tile with the region, in region coordinates;
            # the same extent is used for the values and for the weights.
            top, bottom = max(row, 0), min(row + tile_h, region_h)
            left, right = max(col, 0), min(col + tile_w, region_w)
            if bottom <= top or right <= left:
                continue
            acc[top:bottom, left:right] += tile[top - row:bottom - row, left - col:right - col]
            weight[top:bottom, left:right] += 1
        stitched = acc / np.maximum(weight, 1)
        stitched[weight[:, :, 0] < 1] = 255
        canvas = np.full((self.ori_hw[0], self.ori_hw[1], 3), default_val * 255.0)
        canvas[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]] = stitched
        return canvas

    def write_batch_img(self, imgs, image_sizes, position_indexes):
        """Buffer model output tiles and return every image that is complete.

        Tiles are appended to the streaming buffers. Each full group of
        ``split_hw_overlapped[0] * split_hw_overlapped[1]`` tiles is stitched
        with ``assemble_results``; leftover tiles stay buffered for the next
        call.

        Returns:
            list of uint8 arrays of shape (ori_h, ori_w, 3), possibly empty.
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        if self.results_list is None or self.results_list.shape[0] == 0:
            self.results_list = imgs
            self.position_indexes = position_indexes
        else:
            self.results_list = torch.cat([self.results_list, imgs], dim=0)
            self.position_indexes = torch.cat([self.position_indexes, position_indexes], dim=0)
        self.image_sizes_list += image_sizes

        buffered = self.results_list.shape[0]
        ready = buffered - buffered % tiles_per_image
        out_images = []
        for start in range(0, ready, tiles_per_image):
            stop = start + tiles_per_image
            tiles = self._to_uint8(self.results_list[start:stop])
            positions = self.position_indexes[start:stop].detach().cpu().numpy().astype(int)
            img_out = self.assemble_results(tiles, self.image_sizes_list[start], positions)
            out_images.append(img_out.astype(np.uint8))
        self.results_list = self.results_list[ready:]
        self.position_indexes = self.position_indexes[ready:]
        self.image_sizes_list = self.image_sizes_list[ready:]
        return out_images

    def write_batch_input(self, imgs, image_sizes, position_indexes, default_val=1):
        """Stitch complete tile groups (e.g. the inputs or masks) into images.

        ``imgs`` holds whole images back to back, each made of
        ``split_hw_overlapped[0] * split_hw_overlapped[1]`` tiles; every group
        is stitched with its own slice of ``position_indexes``.

        Returns:
            list of uint8 arrays of shape (ori_h, ori_w, 3).
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        positions = position_indexes.detach().cpu().numpy().astype(int)
        images = []
        for start in range(0, imgs.shape[0], tiles_per_image):
            stop = start + tiles_per_image
            stitched = self.assemble_results(self._to_uint8(imgs[start:stop]), image_sizes[start],
                                             positions[start:stop], default_val)
            images.append(stitched.astype(np.uint8))
        return images

    def generation(self, split_hw, split_overlap, guid_img, img_ori, mask_img_ori, dps_scale, uc_score, ddim_steps, batch_size=1, n_samples=1):
        """Run tiled sampling for one image.

        Args:
            split_hw: [rows, cols] of the tile grid.
            split_overlap: integer overlap factor between neighbouring tiles.
            guid_img: guidance image or None.
            img_ori: input image tensor indexed as (H, W, C).
            mask_img_ori: foreground mask (values > 0.5 are foreground).
            dps_scale, uc_score, ddim_steps: forwarded to the model as
                ``dps_scale``, ``guidance_scale`` and ``ddim_steps``.
            batch_size: must be 1 (the only supported value; the default
                used to be 32, which always failed).
            n_samples: number of independent samples to draw.

        Returns:
            dict with ``"input_image"`` and ``"input_maskes"`` (stitched
            inputs and masks) and ``"out_images"`` (one uint8 image per sample).

        Raises:
            ValueError: if ``batch_size`` is not 1.
        """
        if batch_size != 1:
            raise ValueError(f"only batch_size=1 is supported, got {batch_size}")
        max_batch = __MAX_BATCH__
        operator = GaussialBlurOperator(61, 3.0, self.gpu_id)
        self.split_resolution = None
        self.split_overlap = split_overlap
        self.split_hw = split_hw
        self._reset_stream()

        input_img, cond_img, mask_img, image_sizes, position_indexes = \
            self.prepare_batch(guid_img, img_ori, mask_img_ori, 1)
        input_masked = self.write_batch_input(cond_img, image_sizes, position_indexes)
        input_maskes = self.write_batch_input(mask_img, image_sizes, position_indexes, 0)

        results_all = []
        for _ in range(n_samples):
            for start in range(0, input_img.shape[0], max_batch):
                stop = start + max_batch
                cond_chunk = cond_img[start:stop]
                mask_chunk = mask_img[start:stop]
                if (mask_chunk > 0.5).sum() == 0:
                    # No foreground in these tiles: skip the model, emit white.
                    results = torch.ones_like(cond_chunk)
                else:
                    results = self.model({"cond_img": cond_chunk}, input_img[start:stop], mask_chunk,
                                         ddim_steps=ddim_steps, guidance_scale=uc_score,
                                         dps_scale=dps_scale, as_latent=False, grad_scale=1,
                                         operator=operator)
                results_all += self.write_batch_img(results, image_sizes[start:stop],
                                                    position_indexes[start:stop])
        return {
            "input_image": input_masked,
            "input_maskes": input_maskes,
            "out_images": results_all,
        }