TriplaneGaussian / utils.py
zouzx's picture
update gradio theme
c839178
raw
history blame contribute delete
No virus
3.65 kB
import os
import time
import cv2
import numpy as np
import torch
from PIL import Image
from rembg import remove
from segment_anything import SamPredictor, sam_model_registry
import urllib.request
from tqdm import tqdm
def sam_init(sam_checkpoint, device_id=0, model_type="vit_h"):
    """Load a Segment Anything model and wrap it in a ``SamPredictor``.

    Args:
        sam_checkpoint: Checkpoint passed to the registry builder as
            ``checkpoint=``.
        device_id: CUDA device index to place the model on. Ignored when
            CUDA is unavailable, in which case the model goes to the CPU.
        model_type: Key into ``sam_model_registry`` selecting the
            architecture. Defaults to ``"vit_h"``, the value that used to be
            hard-coded, so existing callers are unaffected. It must match
            the architecture the checkpoint was trained with.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.

    Raises:
        KeyError: If ``model_type`` is not a key of ``sam_model_registry``.
    """
    device = f"cuda:{device_id}" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
    return SamPredictor(sam)
def sam_out_nosave(predictor, input_image, *bbox_sliders):
    """Segment ``input_image`` inside a bounding box and return an RGBA cut-out.

    Args:
        predictor: Object exposing ``set_image(image)`` and
            ``predict(box=..., multimask_output=True)`` returning a
            ``(masks, scores, logits)`` triple (presumably a ``SamPredictor``).
        input_image: Image convertible with ``np.asarray``; its first three
            channels are copied into the output's colour channels.
        *bbox_sliders: Bounding-box coordinates, forwarded to the predictor
            as a single NumPy array.

    Returns:
        A PIL ``RGBA`` image whose colour channels are the input and whose
        alpha channel is the last predicted mask scaled to 0/255.
    """
    box = np.array(bbox_sliders)
    pixels = np.asarray(input_image)

    predictor.set_image(pixels)
    masks, _scores, _logits = predictor.predict(box=box, multimask_output=True)

    rgba = np.zeros((pixels.shape[0], pixels.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = pixels
    # The last returned mask is used on purpose, not the highest-scoring
    # one (i.e. not np.argmax(scores)).
    rgba[:, :, 3] = masks[-1].astype(np.uint8) * 255

    # Hand cached GPU memory back once the prediction is done.
    torch.cuda.empty_cache()
    return Image.fromarray(rgba, mode="RGBA")
# contrast correction, rescale and recenter
def image_preprocess(input_image, save_path, lower_contrast=True, rescale=True):
    """Crop the foreground, centre it on a square canvas and save it as 256x256.

    The foreground is the bounding rectangle of all pixels whose last
    channel (assumed to be alpha, i.e. an RGBA PIL image) is greater than 0.

    Args:
        input_image: PIL image; its last band is used as the foreground mask.
        save_path: Destination handed to ``PIL.Image.Image.save``.
        lower_contrast: If true, scale all channel values by 0.8 before
            cropping, then restore near-opaque alpha values to 255.
        rescale: If true, size the square canvas so that the longer side of
            the foreground fills 75% of it. If false, the canvas side is the
            longer side of the input image, so the foreground keeps its
            original scale.

    Returns:
        None. The result is written to ``save_path``.
    """
    image_arr = np.array(input_image)
    # NumPy image arrays are (rows, cols, channels): height comes first.
    in_h, in_w = image_arr.shape[:2]

    if lower_contrast:
        alpha = 0.8  # Contrast gain (< 1.0 lowers contrast)
        beta = 0  # Brightness offset
        # Apply the contrast adjustment
        image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
        # The scaling also dims the alpha channel (255 -> 204); make pixels
        # that were (nearly) opaque fully opaque again.
        image_arr[image_arr[..., -1] > 200, -1] = 255

    # Binary foreground mask: any non-zero value in the last band counts.
    _, mask = cv2.threshold(
        np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    max_size = max(w, h)
    ratio = 0.75
    if rescale:
        side_len = int(max_size / ratio)
    else:
        # Use the longer image side so the crop always fits on the canvas.
        # Square inputs behave exactly as before; the previous code used the
        # height (mislabelled as width), which overflowed for wide images.
        side_len = max(in_h, in_w)

    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]

    rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS)
    rgba.save(save_path)
def pred_bbox(image):
    """Estimate the foreground bounding box of ``image`` via background removal.

    The image is converted to RGBA and passed through ``remove`` with alpha
    matting enabled; the box spans every column and row whose summed alpha
    is non-zero.

    Args:
        image: PIL image to analyse.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints, with inclusive
        pixel indices.
    """
    cutout = remove(image.convert("RGBA"), alpha_matting=True)
    alpha_channel = np.asarray(cutout)[:, :, -1]

    # Indices of columns / rows that contain at least some opacity.
    cols = np.nonzero(alpha_channel.sum(axis=0))[0]
    rows = np.nonzero(alpha_channel.sum(axis=1))[0]

    return int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max())
# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
    """Decorate ``func`` so it is applied to every leaf of a nested structure.

    The returned wrapper walks ``list``, ``tuple`` and ``dict`` containers
    (including subclasses) recursively and calls ``func`` on every other
    value. Lists come back as lists, tuples as plain tuples and dicts as
    plain dicts with the same keys. Extra positional and keyword arguments
    are forwarded unchanged to every ``func`` call.

    Args:
        func: Callable taking the leaf value as its first argument.

    Returns:
        The recursive wrapper. It carries ``func``'s name and docstring.
    """
    # Imported locally so this block does not rely on a module-level import.
    from functools import wraps

    @wraps(func)  # keep __name__/__doc__ of the decorated function
    def wrapper(vars, *args, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, *args, **kwargs) for x in vars]
        if isinstance(vars, tuple):
            return tuple(wrapper(x, *args, **kwargs) for x in vars)
        if isinstance(vars, dict):
            return {k: wrapper(v, *args, **kwargs) for k, v in vars.items()}
        return func(vars, *args, **kwargs)

    return wrapper
@make_recursive_func
def todevice(vars, device="cuda"):
    """Move tensors to ``device``, leaving plain scalars and strings untouched.

    Because of the ``make_recursive_func`` decorator, this is applied to
    every leaf of nested lists, tuples and dicts.

    Args:
        vars: A ``torch.Tensor``, ``str``, ``bool``, ``float`` or ``int``
            (or a nested container of those).
        device: Target passed to ``Tensor.to``.

    Returns:
        The tensor moved to ``device``, or the value itself for the
        pass-through types.

    Raises:
        NotImplementedError: If a leaf has any other type.
    """
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    # Non-tensor leaves carry no device placement and are returned as-is.
    if isinstance(vars, (str, bool, float, int)):
        return vars
    raise NotImplementedError(
        "invalid input type {} for todevice".format(type(vars))
    )