Freak-ppa's picture
Upload 36 files
d5779bb verified
raw
history blame
8.18 kB
import torch
class ImageResize:
    """Node that resizes an image batch and, optionally, crops or pads it to a side ratio.

    The class follows the node layout used by ComfyUI-style hosts: the
    ``INPUT_TYPES`` / ``VALIDATE_INPUTS`` classmethods describe and check the
    inputs, and ``FUNCTION`` names the method (``resize``) that does the work.
    ``resize`` returns the processed image together with a mask in which
    padded pixels are 1.0.
    """

    def __init__(self):
        pass

    # Values offered for the "action" input.
    ACTION_TYPE_RESIZE = "resize only"
    ACTION_TYPE_CROP = "crop to ratio"
    ACTION_TYPE_PAD = "pad to ratio"

    # Values offered for the "resize_mode" input.
    RESIZE_MODE_DOWNSCALE = "reduce size only"
    RESIZE_MODE_UPSCALE = "increase size only"
    RESIZE_MODE_ANY = "any"

    # Node metadata read by the host application.
    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "resize"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        """Return the declaration of required and optional node inputs."""
        return {
            "required": {
                "pixels": ("IMAGE",),
                "action": ([s.ACTION_TYPE_RESIZE, s.ACTION_TYPE_CROP, s.ACTION_TYPE_PAD],),
                "smaller_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "larger_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "scale_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "resize_mode": ([s.RESIZE_MODE_DOWNSCALE, s.RESIZE_MODE_UPSCALE, s.RESIZE_MODE_ANY],),
                "side_ratio": ("STRING", {"default": "4:3"}),
                "crop_pad_position": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "pad_feathering": ("INT", {"default": 20, "min": 0, "max": 8192, "step": 1}),
            },
            "optional": {
                "mask_optional": ("MASK",),
            },
        }

    @classmethod
    def VALIDATE_INPUTS(s, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, **_):
        """Check the inputs for consistency.

        Returns:
            ``True`` when the inputs are acceptable, otherwise a string
            describing the first problem found. Arguments that are ``None``
            are skipped by the corresponding check.
        """
        if side_ratio is not None:
            # The ratio is only used when cropping or padding.
            if action != s.ACTION_TYPE_RESIZE and s.parse_side_ratio(side_ratio) is None:
                return f"Invalid side ratio: {side_ratio}"

        if smaller_side is not None and larger_side is not None and scale_factor is not None:
            if int(smaller_side > 0) + int(larger_side > 0) + int(scale_factor > 0) > 1:
                return f"At most one scaling rule (smaller_side, larger_side, scale_factor) should be enabled by setting a non-zero value"

        if scale_factor is not None:
            if resize_mode == s.RESIZE_MODE_DOWNSCALE and scale_factor > 1.0:
                return f"For resize_mode {s.RESIZE_MODE_DOWNSCALE}, scale_factor should be less than one but got {scale_factor}"
            if resize_mode == s.RESIZE_MODE_UPSCALE and scale_factor > 0.0 and scale_factor < 1.0:
                return f"For resize_mode {s.RESIZE_MODE_UPSCALE}, scale_factor should be larger than one but got {scale_factor}"

        return True

    @classmethod
    def parse_side_ratio(s, side_ratio):
        """Parse a ``"W:H"`` string of positive integers into the float ``W / H``.

        Returns:
            The ratio as a float, or ``None`` when the value is not a string
            of the form ``"<int>:<int>"`` with both factors >= 1.
        """
        try:
            x, y = map(int, side_ratio.split(":", 1))
        except (AttributeError, TypeError, ValueError):
            # Not a string, wrong number of parts, or non-integer parts.
            return None
        if x < 1 or y < 1:
            return None
        return float(x) / float(y)

    @staticmethod
    def _feather_padded_mask(mask, height, width, add_x, add_y, pad_feathering):
        """Raise mask values along the image edges that border padding, in place.

        Line ``j`` (counted inward from a padded edge) is lifted to at least
        ``(1 - j / pad_feathering) ** 2``. Only edges that actually received
        padding are feathered.
        """
        # Region occupied by the original image inside the padded mask.
        rows = slice(add_y[0], add_y[0] + height)
        cols = slice(add_x[0], add_x[0] + width)

        # Lines beyond the image extent would be skipped anyway.
        for j in range(min(pad_feathering, max(width, height))):
            feather_strength = (1 - j / pad_feathering) * (1 - j / pad_feathering)
            if j < width:
                if add_x[0] > 0:
                    mask[:, rows, add_x[0] + j].clamp_(min=feather_strength)
                if add_x[1] > 0:
                    mask[:, rows, width + add_x[0] - j - 1].clamp_(min=feather_strength)
            if j < height:
                if add_y[0] > 0:
                    mask[:, add_y[0] + j, cols].clamp_(min=feather_strength)
                if add_y[1] > 0:
                    mask[:, height + add_y[0] - j - 1, cols].clamp_(min=feather_strength)

    def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, crop_pad_position, pad_feathering, mask_optional=None):
        """Scale ``pixels`` and crop or pad it to ``side_ratio`` according to ``action``.

        Args:
            pixels: Image tensor; dimensions 1 and 2 are read as height and
                width and the last dimension is treated as channels.
            action: One of the ``ACTION_TYPE_*`` values.
            smaller_side: Target length of the smaller side, 0 to disable.
            larger_side: Target length of the larger side, 0 to disable.
            scale_factor: Explicit scale factor, 0 to disable.
            resize_mode: One of the ``RESIZE_MODE_*`` values; scaling in the
                disallowed direction is skipped.
            side_ratio: ``"W:H"`` ratio used when cropping or padding.
            crop_pad_position: Fraction (0..1) of the cropped/padded amount
                that is applied at the left/top side.
            pad_feathering: Number of edge lines of the mask to feather next
                to padded areas, 0 to disable.
            mask_optional: Optional mask; it is resampled to the image size
                when its height/width differ. Defaults to an all-zero mask.

        Returns:
            Tuple ``(pixels, mask)``.

        Raises:
            Exception: If ``VALIDATE_INPUTS`` rejects the arguments.
        """
        validity = self.VALIDATE_INPUTS(action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio)
        if validity is not True:
            raise Exception(validity)

        height, width = pixels.shape[1:3]
        if mask_optional is None:
            mask = torch.zeros(1, height, width, dtype=torch.float32)
        else:
            mask = mask_optional
            if mask.shape[1] != height or mask.shape[2] != width:
                mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(height, width), mode="bicubic").squeeze(0).clamp(0.0, 1.0)

        # Amount (in source pixels) to remove or add per axis; at most one of
        # the four ends up non-zero.
        crop_x, crop_y, pad_x, pad_y = (0.0, 0.0, 0.0, 0.0)
        if action == self.ACTION_TYPE_CROP:
            target_ratio = self.parse_side_ratio(side_ratio)
            if height * target_ratio < width:
                crop_x = width - height * target_ratio
            else:
                crop_y = height - width / target_ratio
        elif action == self.ACTION_TYPE_PAD:
            target_ratio = self.parse_side_ratio(side_ratio)
            if height * target_ratio > width:
                pad_x = height * target_ratio - width
            else:
                pad_y = width / target_ratio - height

        # Side-length rules are converted into a scale factor based on the
        # dimensions the image will have after cropping/padding.
        final_width = width + pad_x - crop_x
        final_height = height + pad_y - crop_y
        if smaller_side > 0:
            scale_factor = float(smaller_side) / min(final_width, final_height)
        if larger_side > 0:
            scale_factor = float(larger_side) / max(final_width, final_height)

        # Zero means "do not scale"; used when the mode forbids this direction.
        if (resize_mode == self.RESIZE_MODE_DOWNSCALE and scale_factor >= 1.0) or (resize_mode == self.RESIZE_MODE_UPSCALE and scale_factor <= 1.0):
            scale_factor = 0.0

        if scale_factor > 0.0:
            pixels = torch.nn.functional.interpolate(pixels.movedim(-1, 1), scale_factor=scale_factor, mode="bicubic", antialias=True).movedim(1, -1).clamp(0.0, 1.0)
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), scale_factor=scale_factor, mode="bicubic", antialias=True).squeeze(0).clamp(0.0, 1.0)
            height, width = pixels.shape[1:3]

            # Crop/pad amounts were computed for the unscaled image.
            crop_x *= scale_factor
            crop_y *= scale_factor
            pad_x *= scale_factor
            pad_y *= scale_factor

        if crop_x > 0.0 or crop_y > 0.0:
            remove_x = (round(crop_x * crop_pad_position), round(crop_x * (1 - crop_pad_position))) if crop_x > 0.0 else (0, 0)
            remove_y = (round(crop_y * crop_pad_position), round(crop_y * (1 - crop_pad_position))) if crop_y > 0.0 else (0, 0)
            pixels = pixels[:, remove_y[0]:height - remove_y[1], remove_x[0]:width - remove_x[1], :]
            mask = mask[:, remove_y[0]:height - remove_y[1], remove_x[0]:width - remove_x[1]]
        elif pad_x > 0.0 or pad_y > 0.0:
            add_x = (round(pad_x * crop_pad_position), round(pad_x * (1 - crop_pad_position))) if pad_x > 0.0 else (0, 0)
            add_y = (round(pad_y * crop_pad_position), round(pad_y * (1 - crop_pad_position))) if pad_y > 0.0 else (0, 0)
            padded_height = height + add_y[0] + add_y[1]
            padded_width = width + add_x[0] + add_x[1]

            # Padding is black in the image and fully set (1.0) in the mask.
            new_pixels = torch.zeros(pixels.shape[0], padded_height, padded_width, pixels.shape[3], dtype=torch.float32)
            new_pixels[:, add_y[0]:height + add_y[0], add_x[0]:width + add_x[0], :] = pixels
            pixels = new_pixels

            new_mask = torch.ones(mask.shape[0], padded_height, padded_width, dtype=torch.float32)
            new_mask[:, add_y[0]:height + add_y[0], add_x[0]:width + add_x[0]] = mask
            mask = new_mask

            if pad_feathering > 0:
                self._feather_padded_mask(mask, height, width, add_x, add_y, pad_feathering)

        return (pixels, mask)
# Registration tables for this module, keyed by node identifier.
# NODE_CLASS_MAPPINGS maps the identifier to the implementing class.
NODE_CLASS_MAPPINGS = {
"ImageResize": ImageResize
}
# NODE_DISPLAY_NAME_MAPPINGS maps the same identifier to a human-readable
# name (presumably the label shown by the host UI — confirm against the host).
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageResize": "Image Resize"
}