from PIL import Image
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms

# Variables
image_size = 32
classes = ['😠', '🙁', '🙃', '🙂', '😎']


# Functions
def image_transform(image, augment=True):
    """Build and apply the preprocessing (and optional augmentation) pipeline."""
    transform_list = []
    if augment:
        # Random rotation/shear, filling exposed borders with white.
        shear = image.size[0] * .035
        transform_list.append(transforms.RandomAffine(
            degrees=(-22.5, 22.5),
            shear=(-shear, shear),
            fill=255
        ))

        def minmax(v):
            # Randomly rescale a dimension to 66%-110% of its original size.
            return int(v * random.uniform(0.66, 1.1))

        transform_list.append(transforms.Resize((minmax(image.size[0]), minmax(image.size[1]))))

    # Pad to a square canvas, crop to the detected region of interest,
    # then normalise to a 32x32 single-channel, high-contrast tensor.
    transform_list.append(TransformExpandToSquare())
    transform_list.append(TransformCrop())
    transform_list.append(transforms.Grayscale(num_output_channels=1))
    transform_list.append(transforms.Resize(image_size))
    transform_list.append(transforms.ColorJitter(contrast=(10, 10)))  # fixed 10x contrast boost
    transform_list.append(transforms.ToTensor())
    # transform_list.append(transforms.Normalize([0.5], [0.5]))
    # transforms.Normalize(mean=0.0007226799, std=0.0004196318)

    transform = transforms.Compose(transform_list)
    return transform(image)


def extract_image_roi(img):
    """Crop a PIL image to the bounding box of a detected contour."""
    # PIL gives RGB; reverse the channel order to get the BGR layout OpenCV expects.
    img_cv = np.array(img.convert('RGB'))[:, :, ::-1].copy()
    img_gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    # Slicing the last two return values keeps this compatible with both
    # OpenCV 3 (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy).
    contours, _ = cv2.findContours(img_gray, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    cnt = contours[-2]  # second-to-last contour; assumes at least two contours are found
    x, y, w, h = cv2.boundingRect(cnt)
    roi = img_cv[y:y + h, x:x + w]
    conv = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
    return Image.fromarray(conv)


def expand_to_square(img, bg_color):
    """Pad the shorter side of a PIL image so the result is square."""
    if img.size[0] == img.size[1]:
        return img
    max_size = max(img.size)
    min_size = min(img.size)
    result = Image.new(img.mode, (max_size, max_size), bg_color)
    offset = (max_size - min_size) // 2
    xy = (0, offset) if img.size[0] > img.size[1] else (offset, 0)
    result.paste(img, xy)
    return result


def crop(img, bg_color=(255, 255, 255)):
    """Extract the region of interest and pad it back out to a square."""
    return expand_to_square(extract_image_roi(img), bg_color)


# Classes
class Model(nn.Module):
    """Small LeNet-style CNN for 32x32 single-channel inputs."""

    def __init__(self, channels_in, classes):
        super().__init__()
        self.conv1 = nn.Conv2d(channels_in, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 128)  # 16 channels * 5 * 5 after two conv/pool stages
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, classes)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class TransformCrop:
    """Transform wrapper around crop() so it can be used in a Compose pipeline."""

    def __call__(self, img):
        return crop(img)


class TransformExpandToSquare:
    """Transform wrapper that pads an image to a square on a white background."""

    def __call__(self, img):
        return expand_to_square(img, (255, 255, 255))
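

# Usage sketch (illustrative only, not part of the original module): it shows how
# the pieces above are expected to fit together, i.e. load a PIL image, run it
# through image_transform(), and feed the resulting 1x32x32 tensor to Model.
# The path "sample.png" and the untrained model instance are hypothetical
# placeholders for whatever images and weights the surrounding project provides.
if __name__ == "__main__":
    model = Model(channels_in=1, classes=len(classes))
    model.eval()

    sample = Image.open("sample.png").convert("RGB")   # hypothetical input file
    tensor = image_transform(sample, augment=False)    # shape: (1, 32, 32)
    batch = tensor.unsqueeze(0)                        # add batch dim -> (1, 1, 32, 32)

    with torch.no_grad():
        logits = model(batch)
        prediction = classes[int(logits.argmax(dim=1))]
    print("Predicted class:", prediction)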