Spaces:
Runtime error
Runtime error
File size: 8,179 Bytes
d5779bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import torch
class ImageResize:
    """Image-resize node (ComfyUI custom-node interface) with optional crop/pad to a side ratio.

    `resize` rescales an image batch and, depending on `action`, crops or pads it to
    a target width:height ratio. It returns the processed pixels plus a mask in which
    padded areas are 1.0 (optionally feathered into the image), so the mask marks the
    newly added region.
    """

    ACTION_TYPE_RESIZE = "resize only"
    ACTION_TYPE_CROP = "crop to ratio"
    ACTION_TYPE_PAD = "pad to ratio"
    RESIZE_MODE_DOWNSCALE = "reduce size only"
    RESIZE_MODE_UPSCALE = "increase size only"
    RESIZE_MODE_ANY = "any"
    RETURN_TYPES = ("IMAGE", "MASK",)
    FUNCTION = "resize"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required and optional inputs for the host UI."""
        return {
            "required": {
                "pixels": ("IMAGE",),
                "action": ([cls.ACTION_TYPE_RESIZE, cls.ACTION_TYPE_CROP, cls.ACTION_TYPE_PAD],),
                "smaller_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "larger_side": ("INT", {"default": 0, "min": 0, "max": 8192, "step": 8}),
                "scale_factor": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "resize_mode": ([cls.RESIZE_MODE_DOWNSCALE, cls.RESIZE_MODE_UPSCALE, cls.RESIZE_MODE_ANY],),
                "side_ratio": ("STRING", {"default": "4:3"}),
                "crop_pad_position": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "pad_feathering": ("INT", {"default": 20, "min": 0, "max": 8192, "step": 1}),
            },
            "optional": {
                "mask_optional": ("MASK",),
            },
        }

    @classmethod
    def VALIDATE_INPUTS(cls, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, **_):
        """Return True if the inputs are consistent, otherwise an error message string.

        Any argument may be None; checks that depend on a None value are skipped.
        """
        if side_ratio is not None:
            if action != cls.ACTION_TYPE_RESIZE and cls.parse_side_ratio(side_ratio) is None:
                return f"Invalid side ratio: {side_ratio}"
        if smaller_side is not None and larger_side is not None and scale_factor is not None:
            # At most one of the three scaling rules may be active (non-zero).
            if int(smaller_side > 0) + int(larger_side > 0) + int(scale_factor > 0) > 1:
                return "At most one scaling rule (smaller_side, larger_side, scale_factor) should be enabled by setting a non-zero value"
        if scale_factor is not None:
            if resize_mode == cls.RESIZE_MODE_DOWNSCALE and scale_factor > 1.0:
                return f"For resize_mode {cls.RESIZE_MODE_DOWNSCALE}, scale_factor should be less than one but got {scale_factor}"
            if resize_mode == cls.RESIZE_MODE_UPSCALE and scale_factor > 0.0 and scale_factor < 1.0:
                return f"For resize_mode {cls.RESIZE_MODE_UPSCALE}, scale_factor should be larger than one but got {scale_factor}"
        return True

    @classmethod
    def parse_side_ratio(cls, side_ratio):
        """Parse an "X:Y" string of positive integers into the float X / Y.

        Returns None when the value is not a string of that form or a factor is < 1.
        """
        try:
            x, y = map(int, side_ratio.split(":", 1))
        except (AttributeError, TypeError, ValueError):
            # Not a string, wrong number of parts, or a non-integer factor.
            return None
        if x < 1 or y < 1:
            return None
        return float(x) / float(y)

    @staticmethod
    def _prepare_mask(mask_optional, height, width, device):
        """Return a (batch, height, width) mask matching the image size.

        Without a mask, an all-zero mask of batch 1 is created on `device`.
        A 2-D (height, width) mask is treated as a batch of one.
        """
        if mask_optional is None:
            return torch.zeros(1, height, width, dtype=torch.float32, device=device)
        mask = mask_optional
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        if mask.shape[1] != height or mask.shape[2] != width:
            # The mask batch is treated as channels of a single 4-D sample for interpolate.
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=(height, width), mode="bicubic").squeeze(0).clamp(0.0, 1.0)
        return mask

    @classmethod
    def _crop_pad_amounts(cls, action, side_ratio, height, width):
        """Return (crop_x, crop_y, pad_x, pad_y) in pixels needed to reach `side_ratio`.

        The ratio is width:height. At most one of the four values is non-zero.
        """
        crop_x = crop_y = pad_x = pad_y = 0.0
        if action == cls.ACTION_TYPE_CROP:
            target_ratio = cls.parse_side_ratio(side_ratio)
            if height * target_ratio < width:
                crop_x = width - height * target_ratio
            else:
                crop_y = height - width / target_ratio
        elif action == cls.ACTION_TYPE_PAD:
            target_ratio = cls.parse_side_ratio(side_ratio)
            if height * target_ratio > width:
                pad_x = height * target_ratio - width
            else:
                pad_y = width / target_ratio - height
        return crop_x, crop_y, pad_x, pad_y

    @staticmethod
    def _split_amount(amount, position):
        """Split `amount` pixels into integer (before, after) parts placed at `position` in [0, 1].

        The parts always sum to round(amount). Rounding each side independently
        (round(a * p) + round(a * (1 - p))) can be off by one because of
        round-half-to-even, e.g. 3 px at p=0.5 would become 2 + 2.
        """
        if amount <= 0.0:
            return (0, 0)
        total = round(amount)
        before = min(max(round(amount * position), 0), total)
        return (before, total - before)

    @staticmethod
    def _feather_padding(mask, height, width, add_x, add_y, pad_feathering):
        """Raise `mask` in place near padded edges inside the original image region.

        The image occupies rows add_y[0]:add_y[0] + height and columns
        add_x[0]:add_x[0] + width. Along every padded side, the pixel j steps inside
        the image (j < pad_feathering) is raised to at least (1 - j / pad_feathering) ** 2.
        """
        top, left = add_y[0], add_x[0]
        rows = slice(top, top + height)
        cols = slice(left, left + width)

        def ramp(length):
            # Computed in float64 and then cast, to match the per-pixel Python float math.
            steps = torch.arange(length, dtype=torch.float64, device=mask.device)
            return ((1.0 - steps / pad_feathering) ** 2).to(mask.dtype)

        n_x = min(pad_feathering, width)
        ramp_x = ramp(n_x)
        if add_x[0] > 0:
            edge = (slice(None), rows, slice(left, left + n_x))
            mask[edge] = torch.maximum(mask[edge], ramp_x)
        if add_x[1] > 0:
            edge = (slice(None), rows, slice(left + width - n_x, left + width))
            mask[edge] = torch.maximum(mask[edge], ramp_x.flip(0))

        n_y = min(pad_feathering, height)
        ramp_y = ramp(n_y).unsqueeze(1)  # column vector: broadcast along the width
        if add_y[0] > 0:
            edge = (slice(None), slice(top, top + n_y), cols)
            mask[edge] = torch.maximum(mask[edge], ramp_y)
        if add_y[1] > 0:
            edge = (slice(None), slice(top + height - n_y, top + height), cols)
            mask[edge] = torch.maximum(mask[edge], ramp_y.flip(0))

    def resize(self, pixels, action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio, crop_pad_position, pad_feathering, mask_optional=None):
        """Scale `pixels` and optionally crop or pad it to `side_ratio`.

        Args:
            pixels: image batch, indexed as (batch, height, width, channels).
            action: ACTION_TYPE_RESIZE, ACTION_TYPE_CROP or ACTION_TYPE_PAD.
            smaller_side: if > 0, scale so the smaller side of the cropped/padded result has this length.
            larger_side: if > 0, scale so the larger side of the cropped/padded result has this length.
            scale_factor: explicit scale, used when > 0 and no side rule is set.
            resize_mode: RESIZE_MODE_DOWNSCALE / RESIZE_MODE_UPSCALE suppress scaling in the other direction.
            side_ratio: "W:H" ratio used by the crop and pad actions.
            crop_pad_position: fraction (0..1) of the crop/pad applied before (left/top of) the image.
            pad_feathering: length in pixels of the mask ramp blended inward from padded edges.
            mask_optional: optional mask, (batch, H, W) or (H, W); resized to the image if needed.

        Returns:
            Tuple (pixels, mask). Padded areas are zero in pixels and 1.0 in the mask.

        Raises:
            Exception: with the message from VALIDATE_INPUTS when the inputs are inconsistent.
        """
        validity = self.VALIDATE_INPUTS(action, smaller_side, larger_side, scale_factor, resize_mode, side_ratio)
        if validity is not True:
            raise Exception(validity)

        height, width = pixels.shape[1:3]
        mask = self._prepare_mask(mask_optional, height, width, pixels.device)
        crop_x, crop_y, pad_x, pad_y = self._crop_pad_amounts(action, side_ratio, height, width)

        # Side rules refer to the size after cropping/padding.
        result_height = height + pad_y - crop_y
        result_width = width + pad_x - crop_x
        if smaller_side > 0:
            scale_factor = float(smaller_side) / min(result_height, result_width)
        if larger_side > 0:
            scale_factor = float(larger_side) / max(result_height, result_width)
        if (resize_mode == self.RESIZE_MODE_DOWNSCALE and scale_factor >= 1.0) or (resize_mode == self.RESIZE_MODE_UPSCALE and scale_factor <= 1.0):
            scale_factor = 0.0  # scaling in the disallowed direction: keep the size

        if scale_factor > 0.0:
            # Explicit rounded size: interpolate(scale_factor=...) floors size * scale,
            # which can come out one pixel short (e.g. 100 * 0.29 == 28.999...).
            new_size = (max(1, round(height * scale_factor)), max(1, round(width * scale_factor)))
            pixels = torch.nn.functional.interpolate(pixels.movedim(-1, 1), size=new_size, mode="bicubic", antialias=True).movedim(1, -1).clamp(0.0, 1.0)
            mask = torch.nn.functional.interpolate(mask.unsqueeze(0), size=new_size, mode="bicubic", antialias=True).squeeze(0).clamp(0.0, 1.0)
            height, width = new_size
            crop_x *= scale_factor
            crop_y *= scale_factor
            pad_x *= scale_factor
            pad_y *= scale_factor

        if crop_x > 0.0 or crop_y > 0.0:
            remove_x = self._split_amount(crop_x, crop_pad_position)
            remove_y = self._split_amount(crop_y, crop_pad_position)
            pixels = pixels[:, remove_y[0]:height - remove_y[1], remove_x[0]:width - remove_x[1], :]
            mask = mask[:, remove_y[0]:height - remove_y[1], remove_x[0]:width - remove_x[1]]
        elif pad_x > 0.0 or pad_y > 0.0:
            add_x = self._split_amount(pad_x, crop_pad_position)
            add_y = self._split_amount(pad_y, crop_pad_position)
            new_height = height + add_y[0] + add_y[1]
            new_width = width + add_x[0] + add_x[1]
            rows = slice(add_y[0], add_y[0] + height)
            cols = slice(add_x[0], add_x[0] + width)
            # Keep the inputs' dtype/device instead of forcing float32 on CPU.
            new_pixels = pixels.new_zeros((pixels.shape[0], new_height, new_width, pixels.shape[3]))
            new_pixels[:, rows, cols, :] = pixels
            new_mask = mask.new_ones((mask.shape[0], new_height, new_width))
            new_mask[:, rows, cols] = mask
            pixels, mask = new_pixels, new_mask
            if pad_feathering > 0:
                self._feather_padding(mask, height, width, add_x, add_y, pad_feathering)
        return (pixels, mask)
# ComfyUI-style registration tables: node identifier -> implementing class,
# and node identifier -> human-readable name shown in the UI.
NODE_CLASS_MAPPINGS = dict(ImageResize=ImageResize)
NODE_DISPLAY_NAME_MAPPINGS = dict(ImageResize="Image Resize")
|