# IntrinsicAnything / inference.py
# burningdust
# Initial commit
# d72c37e
import os
import imageio
import numpy as np
import glob
import sys
from typing import Any
sys.path.insert(1, '.')
import argparse
from pytorch_lightning import seed_everything
from PIL import Image
import torch
from operators import GaussialBlurOperator
from utils import get_rank
from torchvision.ops import masks_to_boxes
from matfusion import MateralDiffusion
from loguru import logger
# Upper bound on how many image tiles are handed to the diffusion model in a
# single call; InferenceModel.generation slices its tile tensors in chunks of
# this size. The value 4 is the original author's setting for an A10 GPU.
__MAX_BATCH__ = 4 # 4 for A10
def init_model(ckpt_path, ddim, gpu_id):
    """Instantiate ``MateralDiffusion`` from a checkpoint directory.

    The directory is expected to contain ``configs/`` with an entry whose name
    contains ``"project.yaml"`` and ``checkpoints/last.ckpt``.

    Args:
        ckpt_path: Root of the checkpoint directory.
        ddim: Forwarded to the model as ``use_ddim``.
        gpu_id: Forwarded to the model as ``device``.

    Returns:
        The constructed ``MateralDiffusion`` object.

    Raises:
        FileNotFoundError: ``{ckpt_path}/configs`` does not exist.
        IndexError: no entry in ``configs/`` contains ``"project.yaml"``.
    """
    # Locate the project config; the first match in listing order is used.
    matching_configs = []
    for entry in os.listdir(f'{ckpt_path}/configs'):
        if "project.yaml" in entry:
            matching_configs.append(entry)
    model_config = matching_configs[0]

    return MateralDiffusion(
        device=gpu_id,
        fp16=True,
        config=f'{ckpt_path}/configs/{model_config}',
        ckpt=f'{ckpt_path}/checkpoints/last.ckpt',
        vram_O=False,
        t_range=[0.001, 0.02],
        opt=None,
        use_ddim=ddim,
    )
def images_spliter(image, seg_h, seg_w, padding_pixel, padding_val, overlaps=1):
    """Cut an (H, W, C) image tensor into a grid of (optionally overlapping) tiles.

    The image is first trimmed so that its height and width are multiples of
    ``seg_h * overlaps`` and ``seg_w * overlaps``. It is then copied into a
    3-channel canvas with a border of ``padding_pixel`` pixels filled with
    ``padding_val``, and windows are read from that canvas.

    Args:
        image: Tensor with exactly three dimensions (height, width, channels).
        seg_h: Number of non-overlapping segments along the height.
        seg_w: Number of non-overlapping segments along the width.
        padding_pixel: Border width added on every side of the canvas.
        padding_val: Fill value for the border.
        overlaps: Overlap factor; the window stride is the tile size divided
            by this value.

    Returns:
        A 4-tuple ``(tiles, positions, n_rows, n_cols)``: the stacked tiles,
        a float tensor of each tile's (row, col) origin relative to the
        trimmed image (negative inside the border), and the number of tile
        rows and columns actually produced.
    """
    rows, cols, _ = image.shape
    # Drop trailing rows/columns that do not fit the tiling grid.
    rows = rows - (rows % (seg_h * overlaps))
    cols = cols - (cols % (seg_w * overlaps))
    tile_h = rows // seg_h
    tile_w = cols // seg_w

    border = padding_pixel
    canvas = torch.zeros(rows + border * 2, cols + border * 2, 3, device=image.device) + padding_val
    canvas[border:rows + border, border:cols + border, :] = image[:rows, :cols]

    # Overlapped sampling: recompute the grid size and the stride between windows.
    seg_h = np.round((rows - tile_h) / tile_h * overlaps).astype(int) + 1
    seg_w = np.round((cols - tile_w) / tile_w * overlaps).astype(int) + 1
    stride_h = np.round(tile_h / overlaps).astype(int)
    stride_w = np.round(tile_w / overlaps).astype(int)

    tiles = []
    origins = []
    for row_idx in range(0, seg_h):
        top = row_idx * stride_h
        for col_idx in range(0, seg_w):
            left = col_idx * stride_w
            window = canvas[top:top + tile_h + border * 2, left:left + tile_w + border * 2, :]
            tiles.append(window)
            # Origin expressed in un-padded image coordinates.
            origins.append(torch.FloatTensor([top - border, left - border]).reshape(2))

    return torch.stack(tiles, dim=0), torch.stack(origins, dim=0), seg_h, seg_w
class InferenceModel():
    """Tile-based inference driver around the diffusion model.

    An input image is cropped to the bounding box of its mask, split into
    tiles with ``images_spliter``, pushed through the model in chunks, and the
    per-tile outputs are blended back into a full-resolution image.
    """

    def __init__(self, ckpt_path, use_ddim, gpu_id=0):
        """Load the model and set default tiling state.

        Args:
            ckpt_path: Checkpoint directory passed to ``init_model``.
            use_ddim: Passed to ``init_model``.
            gpu_id: Device the model and the input tensors are placed on.
        """
        self.model = init_model(ckpt_path, use_ddim, gpu_id=gpu_id)
        self.gpu_id = gpu_id
        self.split_hw = [1, 1]
        # Defaults so parse_item / write_batch_* can be used before
        # generation() has configured the tiling.
        self.split_overlap = 1
        self.split_hw_overlapped = [1, 1]
        self.padding = 0
        self.padding_crop = 0
        # Buffers used by write_batch_img to collect tiles across model calls.
        self.results_list = None
        self.position_indexes = None
        self.results_output_list = []
        self.image_sizes_list = []

    def parse_item(self, img_ori, mask_img_ori, guid_images):
        """Crop one image to its mask bounding box and split it into tiles.

        NOTE: ``img_ori`` is modified in place - pixels where channel 0 of the
        mask is not above 0.5 are set to 1 (white background, matching the
        training data according to the original author's comment).

        Side effects: sets ``self.ori_hw``, ``self.min_uv``, ``self.max_uv``
        and ``self.split_hw_overlapped``.

        Args:
            img_ori: (H, W, C) image tensor.
            mask_img_ori: (H, W, C') mask tensor. Channel 0 drives the
                background fill and the last channel drives the bounding box.
                NOTE(review): presumably both channels hold the same mask -
                confirm against the caller.
            guid_images: Optional (H, W, C) guidance image, or ``None`` to use
                all-zero guidance tiles.

        Returns:
            ``(guid_tiles, img_tiles, mask_tiles, image_size, positions)``
            where ``mask_tiles`` keeps a single channel, ``image_size`` is the
            shape of the cropped image as a list, and ``positions`` are the
            tile origins from ``images_spliter``.
        """
        img_ori[~(mask_img_ori[..., 0] > 0.5)] = 1
        # With a 1x1 split the real mask is used; otherwise every tile is
        # treated as fully valid.
        use_true_mask = (self.split_hw[0] * self.split_hw[1]) <= 1
        self.ori_hw = list(img_ori.shape)

        # masks_to_boxes yields (x1, y1, x2, y2); reorder to (row, col) and
        # make the max bound exclusive.
        min_max_uv = masks_to_boxes(mask_img_ori[None, ..., -1] > 0.5).long()
        self.min_uv, self.max_uv = min_max_uv[0, ..., [1, 0]], min_max_uv[0, ..., [3, 2]] + 1

        if not use_true_mask:
            # Shrink the box so its size is a multiple of the tiling grid.
            self.max_uv[0] = self.max_uv[0] - ((self.max_uv[0] - self.min_uv[0]) % (self.split_hw[0] * self.split_overlap))
            self.max_uv[1] = self.max_uv[1] - ((self.max_uv[1] - self.min_uv[1]) % (self.split_hw[1] * self.split_overlap))

        mask_img = mask_img_ori[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
        img = img_ori[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
        image_size = list(img.shape)

        if not use_true_mask:
            mask_img = torch.ones_like(mask_img)

        # The mask border is filled with 1 when tiling and 0 otherwise.
        mask_img, _ = images_spliter(mask_img[..., [0, 0, 0]], self.split_hw[0], self.split_hw[1], self.padding, not use_true_mask, self.split_overlap)[:2]
        img, position_indexes, seg_h, seg_w = images_spliter(img, self.split_hw[0], self.split_hw[1], self.padding, 1, self.split_overlap)
        self.split_hw_overlapped = [seg_h, seg_w]
        logger.info(f"Spliting Size: {image_size}, splits: {self.split_hw}, Overlapped: {self.split_hw_overlapped}")

        if guid_images is None:
            guid_images = torch.zeros_like(img)
        else:
            guid_images = guid_images[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]]
            guid_images, _ = images_spliter(guid_images, self.split_hw[0], self.split_hw[1], self.padding, 1, self.split_overlap)[:2]

        return guid_images, img, mask_img[..., :1], image_size, position_indexes

    def prepare_batch(self, guid_img, img_ori, mask_img_ori, batch_size):
        """Run ``parse_item`` ``batch_size`` times and concatenate the tiles.

        Returns:
            ``(input_img, cond_img, mask_img, image_size, position_indexes)``:
            guidance tiles, image tiles, mask tiles and tile positions moved to
            ``self.gpu_id``, plus a Python list holding one image-size entry
            per tile.
        """
        input_tiles, cond_tiles, mask_tiles, positions = [], [], [], []
        image_size = []
        for _ in range(batch_size):
            guid, cond, mask, size, pos = self.parse_item(img_ori, mask_img_ori, guid_img)
            input_tiles.append(guid)
            cond_tiles.append(cond)
            mask_tiles.append(mask)
            positions.append(pos)
            image_size += [size] * guid.shape[0]

        input_img = torch.cat(input_tiles, dim=0).to(self.gpu_id)
        cond_img = torch.cat(cond_tiles, dim=0).to(self.gpu_id)
        mask_img = torch.cat(mask_tiles, dim=0).to(self.gpu_id)
        position_indexes = torch.cat(positions, dim=0).to(self.gpu_id)
        return input_img, cond_img, mask_img, image_size, position_indexes

    def assemble_results(self, img_out, img_hw=None, position_index=None, default_val=1):
        """Blend tiles into one image and paste it into a full-size canvas.

        Overlapping tiles are averaged. Pixels of the cropped region that no
        tile covers are set to 255. The blended region is written into a
        canvas of size ``self.ori_hw`` at ``self.min_uv:self.max_uv``; the rest
        of the canvas is ``default_val * 255``.

        Args:
            img_out: Array of tiles indexed along the first axis, each
                (h, w, C) with C broadcastable to 3.
            img_hw: Height and width (first two entries) of the cropped region.
            position_index: Integer array of (row, col) tile origins, one row
                per tile. It is not modified.
            default_val: Background level (multiplied by 255) outside the
                cropped region.

        Returns:
            Float array of shape ``(ori_h, ori_w, 3)``.
        """
        canvas_h, canvas_w = img_hw[0], img_hw[1]
        accum = np.zeros((canvas_h, canvas_w, 3))
        # Exact per-pixel tile count. (An epsilon-seeded weight would bias every
        # value slightly downwards and lose one level after uint8 truncation.)
        hits = np.zeros((canvas_h, canvas_w, 1))
        margin = self.padding_crop

        for i in range(position_index.shape[0]):
            # Strip the border that should not contribute to the output.
            tile_h, tile_w = img_out[i].shape[:2]
            tile = img_out[i][margin:tile_h - margin, margin:tile_w - margin]
            tile_h, tile_w = tile.shape[:2]
            top = int(position_index[i][0]) + margin
            left = int(position_index[i][1]) + margin

            # Intersect the tile with the canvas: a negative origin means the
            # first rows/cols of the tile fall outside and must be skipped.
            src_top, src_left = max(-top, 0), max(-left, 0)
            dst_top, dst_left = max(top, 0), max(left, 0)
            span_h = min(tile_h - src_top, canvas_h - dst_top)
            span_w = min(tile_w - src_left, canvas_w - dst_left)
            if span_h <= 0 or span_w <= 0:
                continue

            # Values and counts are updated over the identical region.
            accum[dst_top:dst_top + span_h, dst_left:dst_left + span_w] += \
                tile[src_top:src_top + span_h, src_left:src_left + span_w]
            hits[dst_top:dst_top + span_h, dst_left:dst_left + span_w] += 1

        blended = accum / np.maximum(hits, 1)
        blended[hits[:, :, 0] < 1] = 255

        full = (np.zeros((self.ori_hw[0], self.ori_hw[1], 3)) + default_val) * 255
        full[self.min_uv[0]:self.max_uv[0], self.min_uv[1]:self.max_uv[1]] = blended
        return full

    def write_batch_img(self, imgs, image_sizes, position_indexes):
        """Buffer model outputs and assemble every image whose tiles are complete.

        Tiles are appended to internal buffers; each full group of
        ``split_hw_overlapped[0] * split_hw_overlapped[1]`` tiles is turned
        into one image and removed from the buffers. Leftover tiles stay
        buffered for the next call.

        Args:
            imgs: Tensor of output tiles; values are scaled by 255 and cast to
                uint8 (presumably in [0, 1] - confirm against the model).
            image_sizes: One image-size entry per tile in ``imgs``.
            position_indexes: Tensor of tile origins matching ``imgs``.

        Returns:
            List of assembled uint8 images (possibly empty).
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        if self.results_list is None or self.results_list.shape[0] == 0:
            self.results_list = imgs
            self.position_indexes = position_indexes
        else:
            self.results_list = torch.cat([self.results_list, imgs], dim=0)
            self.position_indexes = torch.cat([self.position_indexes, position_indexes], dim=0)
        self.image_sizes_list += image_sizes

        buffered = self.results_list.shape[0]
        valid_len = buffered - (buffered % tiles_per_image)
        out_images = []
        for start in range(0, valid_len, tiles_per_image):
            stop = start + tiles_per_image
            tiles = (self.results_list[start:stop].detach().cpu().numpy() * 255).astype(np.uint8)
            positions = self.position_indexes[start:stop].detach().cpu().numpy().astype(int)
            assembled = self.assemble_results(tiles, self.image_sizes_list[start], positions)
            out_images.append(assembled.astype(np.uint8))

        # Keep only the tiles that did not yet form a complete image.
        self.results_list = self.results_list[valid_len:]
        self.position_indexes = self.position_indexes[valid_len:]
        self.image_sizes_list = self.image_sizes_list[valid_len:]
        return out_images

    def write_batch_input(self, imgs, image_sizes, position_indexes, default_val=1):
        """Assemble input tiles back into full images without buffering.

        Args:
            imgs: Tensor of tiles, grouped per image in consecutive runs of
                ``split_hw_overlapped[0] * split_hw_overlapped[1]`` tiles.
            image_sizes: One image-size entry per tile.
            position_indexes: Tensor of tile origins matching ``imgs``.
            default_val: Background level forwarded to ``assemble_results``.

        Returns:
            List of uint8 images, one per group of tiles.
        """
        tiles_per_image = self.split_hw_overlapped[0] * self.split_hw_overlapped[1]
        images = []
        for start in range(0, imgs.shape[0], tiles_per_image):
            stop = start + tiles_per_image
            tiles = (imgs[start:stop].detach().cpu().numpy() * 255).astype(np.uint8)
            # Only the positions belonging to this image's tiles.
            positions = position_indexes[start:stop].detach().cpu().numpy().astype(int)
            assembled = self.assemble_results(tiles, image_sizes[start], positions, default_val)
            images.append(assembled.astype(np.uint8))
        return images

    def generation(self, split_hw, split_overlap, guid_img, img_ori, mask_img_ori, dps_scale, uc_score, ddim_steps, batch_size=32, n_samples=1):
        """Run tiled inference on one image.

        Args:
            split_hw: ``[rows, cols]`` number of tiles before overlap.
            split_overlap: Overlap factor forwarded to ``images_spliter``.
            guid_img: Optional guidance image (see ``parse_item``).
            img_ori: Input image; modified in place by ``parse_item``.
            mask_img_ori: Foreground mask.
            dps_scale: Forwarded to the model as ``dps_scale``.
            uc_score: Forwarded to the model as ``guidance_scale``.
            ddim_steps: Forwarded to the model as ``ddim_steps``.
            batch_size: Must be 1; callers have to pass it explicitly because
                the default does not satisfy this check.
            n_samples: Number of independent samples to generate.

        Returns:
            Dict with ``"input_image"`` (re-assembled input), ``"input_maskes"``
            (re-assembled mask) and ``"out_images"`` (one assembled uint8
            image per sample).
        """
        max_batch = __MAX_BATCH__
        operator = GaussialBlurOperator(61, 3.0, self.gpu_id)
        assert batch_size == 1, "generation() only supports batch_size == 1"
        self.split_resolution = None
        self.split_overlap = split_overlap
        self.split_hw = split_hw

        # Drop tiles left over from a previous, interrupted run so they cannot
        # be mixed into this image.
        self.results_list = None
        self.position_indexes = None
        self.image_sizes_list = []

        input_img, cond_img, mask_img, image_sizes, position_indexes = self.prepare_batch(guid_img, img_ori, mask_img_ori, 1)
        input_masked = self.write_batch_input(cond_img, image_sizes, position_indexes)
        input_maskes = self.write_batch_input(mask_img, image_sizes, position_indexes, 0)

        results_all = []
        for _ in range(n_samples):
            for start in range(0, input_img.shape[0], max_batch):
                stop = start + max_batch
                cond_chunk = cond_img[start:stop]
                mask_chunk = mask_img[start:stop]
                if (mask_chunk > 0.5).sum() == 0:
                    # Nothing to predict in these tiles: emit plain white.
                    results = torch.ones_like(cond_chunk)
                else:
                    embeddings = {"cond_img": cond_chunk}
                    results = self.model(embeddings, input_img[start:stop], mask_chunk, ddim_steps=ddim_steps,
                                         guidance_scale=uc_score, dps_scale=dps_scale, as_latent=False, grad_scale=1, operator=operator)
                results_all += self.write_batch_img(results, image_sizes[start:stop], position_indexes[start:stop])

        ret = {
            "input_image": input_masked,
            "input_maskes": input_maskes,
            "out_images": results_all
        }
        return ret