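# NOTE: residue from the repository web viewer, kept as a comment so the file parses:
# uploader "lev1", commit message "Initial commit", revision 8fd2f2f.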
import io
import math
import os
import PIL.Image
import numpy as np
import imageio.v3 as iio
import warnings
from torchvision.utils import flow_to_image
import torch
import torchvision.transforms.functional as TF
from scipy.ndimage import binary_dilation, binary_erosion
import cv2
from ..animation import Animation
from .. import config
from .. import libimage
import re
def torch2np(x, vmin=-1, vmax=1):
    '''
    Convert a torch image tensor to a channel-last uint8 numpy array.

    Args:
        x: tensor expected in (B,C,H,W) layout. A 3D tensor gets one leading
            axis and a 2D tensor gets two leading axes; a warning is emitted
            whenever the input is not 4D.
        vmin, vmax: value range of ``x``. Values are clipped to this range
            and mapped linearly onto [0, 255].

    Returns:
        uint8 numpy array of shape (B,H,W,C).

    Raises:
        AssertionError: if axis 1 of the (possibly expanded) tensor does not
            have size 1 or 3.
        NotImplementedError: if ``vmin`` or ``vmax`` is None.
    '''
    if x.ndim != 4:
        # raise Exception("Please only use (B,C,H,W) torch tensors!")
        warnings.warn(
            "Warning! Shape of the image was not provided in (B,C,H,W) format, the shape was inferred automatically!")
        if x.ndim == 3:
            x = x[None]
        elif x.ndim == 2:
            x = x[None, None]
    assert x.shape[1] == 3 or x.shape[1] == 1
    if vmin is None or vmax is None:
        raise NotImplementedError()
    # The previous version tested `x.dtype == torch.uint8` *after* casting to
    # float, so that branch could never run. uint8 inputs have therefore
    # always been rescaled like any other dtype; that behaviour is kept and
    # the dead branch removed.
    x = x.detach().cpu().float()
    x = 255 * (x.clip(vmin, vmax) - vmin) / (vmax - vmin)
    # The uint8 cast truncates towards zero (no rounding).
    return x.permute(0, 2, 3, 1).to(torch.uint8).numpy()
class IImage:
    '''
    Generic media storage. Can store both images and videos.
    Stores data as a uint8 numpy array by default; the methods below index it
    as (B,H,W,C), where B > 1 is treated as a video.
    Can be viewed in a jupyter notebook.
    '''

    @staticmethod
    def open(path):
        '''
        Read an image or animation from ``path`` and wrap it in an IImage.

        The absolute path is stored in ``link`` of the returned object.
        '''
        # Use the plugin as a context manager so the underlying file handle
        # is released once the pixel data has been read.
        with iio.imopen(path, 'r') as iio_obj:
            data = iio_obj.read()
            try:
                # .properties() does not work for images but for gif files
                if not iio_obj.properties().is_batch:
                    data = data[None]
            except AttributeError:
                # this one works for gif files
                if "duration" not in iio_obj.metadata():
                    data = data[None]
        if data.ndim == 3:
            # (B,H,W) -> (B,H,W,1): add the missing channel axis.
            data = data[..., None]
        image = IImage(data)
        image.link = os.path.abspath(path)
        return image

    @staticmethod
    def flow_field(flow):
        '''Render ``flow`` with torchvision's ``flow_to_image`` and wrap the result.'''
        flow_images = flow_to_image(flow)
        return IImage(flow_images, vmin=0, vmax=255)

    @staticmethod
    def normalized(x, dims=(-1, -2)):
        '''
        Min-max normalise ``x`` over ``dims`` to [0, 1] and wrap it.

        ``x`` must provide ``amin``/``amax`` (presumably a torch tensor).
        No guard against a zero value range is applied.
        '''
        lowest = x.amin(dims, True)
        highest = x.amax(dims, True)
        return IImage((x - lowest) / (highest - lowest), 0)

    def numpy(self):
        '''Return the underlying numpy array (not a copy).'''
        return self.data

    def torch(self, vmin=-1, vmax=1):
        '''
        Return the data as a float, channel-first tensor on ``self.device``,
        with byte values 0..255 mapped linearly onto [vmin, vmax].
        '''
        axes = (2, 0, 1) if self.data.ndim == 3 else (0, 3, 1, 2)
        data = self.data.transpose(*axes) / 255.
        return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)

    def cuda(self):
        '''Make subsequent ``torch()`` calls place tensors on 'cuda'. Returns self.'''
        self.device = 'cuda'
        return self

    def cpu(self):
        '''Make subsequent ``torch()`` calls place tensors on 'cpu'. Returns self.'''
        self.device = 'cpu'
        return self

    def pil(self):
        '''
        Convert to PIL. Returns a single image when there is exactly one
        frame, otherwise a list with one image per frame.
        '''
        ans = []
        for frame in self.data:
            if frame.shape[-1] == 1:
                # Drop the singleton channel axis before handing over to PIL.
                frame = frame[..., 0]
            ans.append(PIL.Image.fromarray(frame))
        if len(ans) == 1:
            return ans[0]
        return ans

    def is_iimage(self):
        '''Marker method; always returns True.'''
        return True

    @property
    def shape(self):
        '''Shape of the underlying array.'''
        return self.data.shape

    @property
    def size(self):
        '''(width, height), i.e. (shape[-2], shape[-3]).'''
        return (self.data.shape[-2], self.data.shape[-3])

    def setFps(self, fps):
        '''Set the frame rate and re-render the cached display. Returns self.'''
        self.fps = fps
        self.generate_display()
        return self

    def __init__(self, x, vmin=-1, vmax=1, fps=None):
        '''
        Build an IImage from a PIL image, another IImage, a numpy array or a
        torch tensor.

        Args:
            x: source data. numpy arrays are cast to uint8; 2D arrays become
                (1,H,W,1) and 3D arrays are taken as (H,W,C). Tensors go
                through ``torch2np``.
            vmin, vmax: value range, only used for tensor input.
            fps: frame rate; when None it defaults to 1 for fewer than 10
                frames and 30 otherwise.

        Raises:
            AssertionError: if a tensor has values outside [vmin, vmax].
            TypeError: if ``x`` is of an unsupported type.
        '''
        if isinstance(x, PIL.Image.Image):
            self.data = np.array(x)
            if self.data.ndim == 2:
                self.data = self.data[..., None]  # (H,W,C)
            self.data = self.data[None]  # (B,H,W,C)
        elif isinstance(x, IImage):
            self.data = x.data.copy()  # Simple Copy
        elif isinstance(x, np.ndarray):
            # astype already returns a fresh array, no extra copy needed.
            self.data = x.astype(np.uint8)
            if self.data.ndim == 2:
                self.data = self.data[None, ..., None]
            if self.data.ndim == 3:
                warnings.warn(
                    "Inferred dimensions for a 3D array as (H,W,C), but could've been (B,H,W)")
                self.data = self.data[None]
        elif isinstance(x, torch.Tensor):
            assert x.min() >= vmin and x.max(
            ) <= vmax, f"input data was [{x.min()},{x.max()}], but expected [{vmin},{vmax}]"
            self.data = torch2np(x, vmin, vmax)
        else:
            # Previously `self.data` stayed unset here and the failure only
            # surfaced later as an AttributeError.
            raise TypeError(
                f"IImage cannot be constructed from {type(x).__name__}")
        self.display_str = None
        self.device = 'cpu'
        self.fps = fps if fps is not None else (
            1 if len(self.data) < 10 else 30)
        self.link = None

    def generate_display(self):
        '''
        Render and cache the notebook representation in ``display_str``:
        the ``Animation`` string for videos, PNG bytes for single images.
        When ``config.IMG_THUMBSIZE`` is set, a thumbnail whose longer edge
        equals that value is rendered instead of the full-size data.
        '''
        thumb_size = config.IMG_THUMBSIZE
        if thumb_size is None:
            thumb = self
        else:
            width, height = self.size
            # resize() takes (height, width).
            if height < width:
                thumb = self.resize((height * thumb_size // width, thumb_size))
            else:
                thumb = self.resize((thumb_size, width * thumb_size // height))
        if self.is_video():
            self.anim = Animation(thumb.data, fps=self.fps)
            self.anim.render()
            self.display_str = self.anim.anim_str
        else:
            buffer = io.BytesIO()
            data = thumb.data[0]
            if data.shape[-1] == 1:
                data = data[..., 0]
            PIL.Image.fromarray(data).save(buffer, "PNG")
            self.display_str = buffer.getvalue()
        return self.display_str

    def resize(self, size, *args, **kwargs):
        '''
        Resize every frame with PIL.

        Args:
            size: None (no-op), an int, or a (height, width) pair. For an int
                the aspect ratio is kept: by default the longer edge becomes
                ``size``; with ``use_small_edge_when_int=True`` the shorter
                edge does.
            resample / filter (keyword): PIL resampling filter, bicubic by
                default. ``filter`` is kept for backward compatibility and
                ``resample`` wins when both are given.

        Returns:
            self when nothing has to change, otherwise the result of
            ``libimage.stack`` over the resized frames.
        '''
        if size is None:
            return self
        use_small_edge_when_int = kwargs.pop('use_small_edge_when_int', False)
        # Backward compatibility
        resample = kwargs.pop('filter', PIL.Image.BICUBIC)
        resample = kwargs.pop('resample', resample)
        if isinstance(size, int):
            h, w = self.data.shape[1:3]
            aspect_ratio = h / w
            pick = max if use_small_edge_when_int else min
            size = (pick(size, int(size * aspect_ratio)),
                    pick(size, int(size / aspect_ratio)))
        # PIL wants (width, height). Normalise to a tuple so the equality
        # check below also works when `size` was passed as a list.
        target_wh = tuple(size[::-1])
        if self.size == target_wh:
            return self
        frames = [self[i] for i in range(len(self.data))]
        return libimage.stack([
            IImage(frame.pil().resize(target_wh, *args, resample=resample, **kwargs))
            for frame in frames
        ])
        # return IImage(TF.resize(self.cpu().torch(0), size, *args, **kwargs), 0)

    def pad(self, padding, *args, **kwargs):
        '''Pad via torchvision's ``TF.pad`` on the [0, 1] tensor; extra arguments are forwarded.'''
        return IImage(TF.pad(self.torch(0), padding=padding, *args, **kwargs), 0)

    def padx(self, multiplier, *args, **kwargs):
        '''Pad so that width and height become multiples of ``multiplier``.'''
        size = np.array(self.size)
        # First two entries of the padding list stay 0; the remainder is
        # appended as the last two entries.
        padding = np.concatenate(
            [[0, 0], np.ceil(size / multiplier).astype(int) * multiplier - size])
        return self.pad(list(padding), *args, **kwargs)

    def pad2wh(self, w=0, h=0, **kwargs):
        '''Pad up to at least ``w`` x ``h``; never shrinks.'''
        cw, ch = self.size
        return self.pad([0, 0, max(0, w - cw), max(0, h - ch)], **kwargs)

    def pad2square(self, *args, **kwargs):
        '''Pad the shorter dimension evenly on both sides to make the image square.'''
        width, height = self.size
        if width > height:
            dx = width - height
            return self.pad([0, dx // 2, 0, dx - dx // 2], *args, **kwargs)
        if width < height:
            dx = height - width
            return self.pad([dx // 2, 0, dx - dx // 2, 0], *args, **kwargs)
        return self

    def crop2square(self, *args, **kwargs):
        '''Center-crop the longer dimension to make the image square.'''
        width, height = self.size
        if width > height:
            dx = width - height
            return self.crop([dx // 2, 0, height, height], *args, **kwargs)
        if width < height:
            dx = height - width
            return self.crop([0, dx // 2, width, width], *args, **kwargs)
        return self

    def alpha(self):
        '''Return the last channel as a single-channel IImage.'''
        return IImage(self.data[..., -1, None], fps=self.fps)

    def rgb(self):
        '''Convert every frame to RGB through PIL.'''
        frames = self.pil()
        if not isinstance(frames, list):
            frames = [frames]
        # pil() returns a list for videos; convert frame by frame instead of
        # calling .convert on the list.
        data = np.stack([np.array(frame.convert('RGB')) for frame in frames])
        return IImage(data, fps=self.fps)

    def png(self):
        '''Append a fully opaque (255) channel as the last channel.'''
        opaque = 255 * np.ones_like(self.data)[..., :1]
        return IImage(np.concatenate([self.data, opaque], -1))

    def grid(self, nrows=None, ncols=None):
        '''
        Tile all frames into one image. ``nrows`` takes precedence over
        ``ncols``; with neither given, 5 columns are used. Missing cells are
        zero-filled.
        '''
        count = self.data.shape[0]
        if nrows is not None:
            ncols = math.ceil(count / nrows)
        elif ncols is not None:
            nrows = math.ceil(count / ncols)
        else:
            warnings.warn(
                "No dimensions specified, creating a grid with 5 columns (default)")
            ncols = 5
            nrows = math.ceil(count / ncols)
        pad = nrows * ncols - count
        data = np.pad(self.data, ((0, pad), (0, 0), (0, 0), (0, 0)))
        rows = [np.concatenate(chunk, 1, dtype=np.uint8)
                for chunk in np.array_split(data, nrows)]
        return IImage(np.concatenate(rows, 0, dtype=np.uint8)[None])

    def hstack(self):
        '''Concatenate all frames side by side into a single image.'''
        return IImage(np.concatenate(self.data, 1, dtype=np.uint8)[None])

    def vstack(self):
        '''Concatenate all frames top to bottom into a single image.'''
        return IImage(np.concatenate(self.data, 0, dtype=np.uint8)[None])

    def vsplit(self, number_of_splits):
        '''Split along the height axis and stack the pieces as frames.'''
        return IImage(np.concatenate(np.split(self.data, number_of_splits, 1)))

    def hsplit(self, number_of_splits):
        '''Split along the width axis and stack the pieces as frames.'''
        return IImage(np.concatenate(np.split(self.data, number_of_splits, 2)))

    def heatmap(self, resize=None, cmap=cv2.COLORMAP_JET):
        '''Apply an OpenCV colormap to every frame (output in RGB order), then optionally resize.'''
        data = np.stack([cv2.cvtColor(cv2.applyColorMap(
            frame, cmap), cv2.COLOR_BGR2RGB) for frame in self.data])
        return IImage(data).resize(resize, use_small_edge_when_int=True)

    def display(self):
        '''Show the image via ``display`` if that name is available; returns self.'''
        try:
            # `display` is not imported in this file; presumably the builtin
            # that IPython/Jupyter injects. Outside a notebook the resulting
            # NameError ends up in the handler below.
            display(self)
        except Exception:
            print("No display")
        return self

    def dilate(self, iterations=1, *args, **kwargs):
        '''
        Binary dilation via ``scipy.ndimage.binary_dilation``; the result is
        0/255. ``iterations == 0`` returns an unchanged copy.
        '''
        if iterations == 0:
            return IImage(self.data)
        # NOTE(review): the morphology runs on the full 4D array with SciPy's
        # default structuring element unless `structure` is passed - confirm
        # that spreading across frames/channels is intended.
        dilated = binary_dilation(
            self.data, *args, iterations=iterations, **kwargs)
        return IImage((dilated * 255.).astype(np.uint8))

    def erode(self, iterations=1, *args, **kwargs):
        '''
        Binary erosion via ``scipy.ndimage.binary_erosion``; the result is
        0/255. ``iterations == 0`` returns an unchanged copy, matching
        ``dilate``.
        '''
        if iterations == 0:
            return IImage(self.data)
        eroded = binary_erosion(
            self.data, *args, iterations=iterations, **kwargs)
        return IImage((eroded * 255.).astype(np.uint8))

    def hull(self):
        '''
        Draw, per frame, the filled convex hull of all contour points.
        Frames without any contour yield an all-zero canvas.
        '''
        convex_hulls = []
        for frame in self.data:
            contours, _hierarchy = cv2.findContours(
                frame, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            canvas = np.zeros(self.data[0].shape, np.uint8)
            if len(contours) > 0:
                points = np.concatenate(
                    [contour.astype(np.int32) for contour in contours])
                canvas = cv2.drawContours(
                    canvas, [cv2.convexHull(points)], -1, (255, 0, 0), -1)
            convex_hulls.append(canvas)
        return IImage(np.array(convex_hulls))

    def is_video(self):
        '''True when more than one frame is stored.'''
        return self.data.shape[0] > 1

    def __getitem__(self, idx):
        '''
        Select frames. An integer yields a single-frame IImage; a slice or
        index array yields the selected frames. An out-of-range integer
        raises IndexError, which also terminates ``for frame in image``.
        '''
        frames = self.data[idx]
        if frames.ndim == 3:
            # A single frame was selected: restore the leading frame axis.
            frames = frames[None]
        return IImage(frames, fps=self.fps)
        # if self.is_video(): return IImage(self.data[idx], fps = self.fps)
        # return self

    def _repr_png_(self):
        '''Notebook hook: PNG bytes for single images, None for videos.'''
        if self.is_video():
            return None
        if self.display_str is None:
            self.generate_display()
        return self.display_str

    def _repr_html_(self):
        '''Notebook hook: animation string for videos, None for single images.'''
        if not self.is_video():
            return None
        if self.display_str is None:
            self.generate_display()
        return self.display_str

    def save(self, path):
        '''
        Write to ``path``. Videos are saved through the rendered animation
        object (with the "pillow" writer for ``.apng``); single images are
        saved with PIL. Returns self.
        '''
        _, ext = os.path.splitext(path)
        if self.is_video():
            # if ext in ['.jpg', '.png']:
            if self.display_str is None:
                # Rendering also creates `self.anim`, which is needed below.
                self.generate_display()
            if ext == ".apng":
                self.anim.anim_obj.save(path, writer="pillow")
            else:
                self.anim.anim_obj.save(path)
        else:
            data = self.data if self.data.ndim == 3 else self.data[0]
            if data.shape[-1] == 1:
                data = data[:, :, 0]
            PIL.Image.fromarray(data).save(path)
        return self

    def to_html(self, width='auto', root_path='/'):
        '''
        Return an HTML snippet for the image; when ``link`` is set the
        snippet is wrapped in an anchor relative to ``root_path``.
        '''
        if self.display_str is None:
            self.generate_display()
        # print (self.display_str)
        # NOTE(review): `bytes2html` is neither defined nor imported in the
        # visible part of this file - confirm where it comes from.
        html_tag = bytes2html(self.display_str, width=width)
        if self.link is not None:
            link = os.path.relpath(self.link, root_path)
            return f'<a href="{link}" >{html_tag}</a>'
        return html_tag

    def write(self, text, center=(0, 25), font_scale=0.8, color=(255, 255, 255), thickness=2):
        '''
        Draw ``text`` on every frame with ``cv2.putText``. ``center`` is
        passed to ``cv2.putText`` as the text origin. A list supplies one
        string per frame.
        '''
        if not isinstance(text, list):
            text = [text for _ in self.data]
        data = np.stack([
            cv2.putText(frame.copy(), caption, center, cv2.FONT_HERSHEY_COMPLEX,
                        font_scale, color, thickness)
            for frame, caption in zip(self.data, text)
        ])
        return IImage(data)

    def append_text(self, text, padding, font_scale=0.8, color=(255, 255, 255), thickness=2, scale_factor=0.9, center=(0, 0), fill=0):
        '''
        Pad the image on exactly one side and write word-wrapped text into
        the padded area.

        Args:
            text: a string, or a list with one string per frame.
            padding: 4-element padding with exactly one non-zero entry; it is
                forwarded to ``pad`` and its index selects the text area.
            font_scale: initial font scale; multiplied by ``scale_factor``
                repeatedly until the wrapped text fits.
            center: (x, y) offset applied to the text origin; the x offset
                also reduces the available line width.
            fill: fill value forwarded to ``pad``.
        '''
        assert np.count_nonzero(padding) == 1
        axis_padding = np.nonzero(padding)[0][0]
        scale_padding = padding[axis_padding]

        # Text box geometry, measured on the image *before* padding.
        y_0 = 0
        x_0 = 0
        if axis_padding == 0:
            width = scale_padding
            y_max = self.shape[1]
        elif axis_padding == 1:
            width = self.shape[2]
            y_max = scale_padding
        elif axis_padding == 2:
            x_0 = self.shape[2]
            width = scale_padding
            y_max = self.shape[1]
        elif axis_padding == 3:
            width = self.shape[2]
            y_0 = self.shape[1]
            y_max = self.shape[1] + scale_padding

        width -= center[0]
        x_0 += center[0]
        y_0 += center[1]

        padded = self.pad(padding, fill=fill)

        def text_width(line, _font_scale):
            return cv2.getTextSize(
                line, cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][0]

        def wrap_text(text, width, _font_scale):
            # Greedy wrapping; returns [] when some line cannot fit `width`.
            allowed_seperator = ' |-|_|/|\n'
            words = re.split(allowed_seperator, text)
            # words = text.split()
            lines = []
            current_line = words[0]
            # Recover which separator followed each word so that joined
            # lines keep the original characters.
            sep_list = []
            start_idx = 0
            for start_word in words[:-1]:
                pos = text.find(start_word, start_idx)
                pos += len(start_word)
                sep_list.append(text[pos])
                start_idx = pos + 1

            for word, separator in zip(words[1:], sep_list):
                if text_width(current_line + separator + word, _font_scale) <= width:
                    current_line += separator + word
                elif text_width(current_line, _font_scale) <= width:
                    lines.append(current_line)
                    current_line = word
                else:
                    return []

            if text_width(current_line, _font_scale) <= width:
                lines.append(current_line)
            else:
                return []
            return lines

        def wrap_text_and_scale(text, width, _font_scale, y_0, y_max):
            # Shrink the font until the wrapped text fits below y_max.
            height = y_max + 1
            while height > y_max:
                text_lines = wrap_text(text, width, _font_scale)
                if len(text) > 0 and len(text_lines) == 0:
                    height = y_max + 1
                else:
                    line_height = cv2.getTextSize(
                        text_lines[0], cv2.FONT_HERSHEY_COMPLEX, _font_scale, thickness)[0][1]
                    height = line_height * len(text_lines) + y_0

                # scale font if out of frame
                if height > y_max:
                    _font_scale = _font_scale * scale_factor

            return text_lines, line_height, _font_scale

        if not isinstance(text, list):
            text = [text for _ in padded.data]
        else:
            assert len(text) == len(padded.data)

        result = []
        for frame, caption in zip(padded.data, text):
            frame = frame.copy()
            text_lines, line_height, _font_scale = wrap_text_and_scale(
                caption, width, font_scale, y_0, y_max)
            y = line_height
            for line in text_lines:
                frame = cv2.putText(
                    frame, line, (x_0, y_0 + y), cv2.FONT_HERSHEY_COMPLEX, _font_scale, color, thickness)
                y += line_height
            result.append(frame)
        data = np.stack(result)
        return IImage(data)

    # ========== OPERATORS =============

    def __or__(self, other):
        '''``a | b``: concatenate along the width axis.'''
        # TODO: fix for variable sizes
        return IImage(np.concatenate([self.data, other.data], 2))

    def __truediv__(self, other):
        '''``a / b``: concatenate along the height axis.'''
        # TODO: fix for variable sizes
        return IImage(np.concatenate([self.data, other.data], 1))

    def __and__(self, other):
        '''``a & b``: concatenate along the frame axis.'''
        return IImage(np.concatenate([self.data, other.data], 0))

    def __add__(self, other):
        '''``a + b``: 50/50 blend of the two images.'''
        return IImage(0.5 * self.data + 0.5 * other.data)

    def __mul__(self, other):
        '''``a * b``: multiply by another IImage (treated as 0..1) or by a plain factor divided by 255.'''
        if isinstance(other, IImage):
            return IImage(self.data / 255. * other.data)
        return IImage(self.data * other / 255.)

    def __xor__(self, other):
        '''``a ^ b``: 50/50 blend that keeps ``self`` at full strength where ``other`` sums to zero over channels.'''
        other_is_empty = other.data.sum(-1, keepdims=True) == 0
        return IImage(0.5 * self.data + 0.5 * other.data + 0.5 * self.data * other_is_empty)

    def __invert__(self):
        '''``~a``: invert pixel values (255 - value).'''
        return IImage(255 - self.data)

    __rmul__ = __mul__

    def bbox(self):
        '''Return ``cv2.boundingRect`` of every frame as a list.'''
        return [cv2.boundingRect(frame) for frame in self.data]

    def fill_bbox(self, bbox_list, fill=255):
        '''Fill every (x, y, w, h) box of ``bbox_list`` with ``fill`` in all frames.'''
        data = self.data.copy()
        for x, y, w, h in bbox_list:
            data[:, y:y + h, x:x + w, :] = fill
        return IImage(data)

    def crop(self, bbox):
        '''Crop to ``bbox``, given either as (w, h) anchored at the origin or as (x, y, w, h).'''
        assert len(bbox) in [2, 4]
        if len(bbox) == 2:
            x, y = 0, 0
            w, h = bbox
        else:
            x, y, w, h = bbox
        return IImage(self.data[:, y:y + h, x:x + w, :])
# def alpha(self):
# return BetterImage(self.img.split()[-1])
# def resize(self, size, *args, **kwargs):
# if size is None: return self
# return BetterImage(TF.resize(self.img, size, *args, **kwargs))
# def pad(self, *args):
# return BetterImage(TF.pad(self.img, *args))
# def padx(self, mult):
# size = np.array(self.img.size)
# padding = np.concatenate([[0,0],np.ceil(size / mult).astype(int) * mult - size])
# return self.pad(list(padding))
# def crop(self, *args):
# return BetterImage(self.img.crop(*args))
# def torch(self, min = -1., max = 1.):
# return (max - min) * TF.to_tensor(self.img)[None] + min