import torch
import torch.nn as nn
import torch.nn.functional as F


class MangaColorizer(nn.Module):
    """Convolutional encoder-decoder mapping a single-channel image to a colour image.

    The encoder downsamples spatially by 4x (two stride-2 convolutions) and the
    decoder upsamples by 4x (two stride-2 transposed convolutions). The final
    ``tanh`` keeps every output value in ``[-1, 1]``.

    Inputs whose height/width are not multiples of 4 are zero-padded on the
    bottom/right before encoding, and the result is cropped back, so the output
    always has the same spatial size as the input.

    Args:
        in_channels: Number of input channels (default 1).
        out_channels: Number of output channels (default 3).
        base_channels: Width of the first encoder layer; the deeper layers use
            2x and 4x this width (default 64).
    """

    # Total spatial downsampling factor of the encoder (two stride-2 convs).
    _DOWNSAMPLE = 4

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 3,
        base_channels: int = 64,
    ) -> None:
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4
        # Layer order (and therefore Sequential indices) matches the original
        # definition so existing state_dict keys such as "encoder.0.weight"
        # and "decoder.4.bias" still load.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, c1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c2, c3, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            # kernel 4 / stride 2 / padding 1 exactly doubles the spatial size.
            nn.ConvTranspose2d(c3, c2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(c2, c1, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(c1, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Colorize ``x``.

        Args:
            x: Image tensor whose last two dimensions are height and width,
                e.g. ``(N, in_channels, H, W)``.

        Returns:
            Tensor with ``out_channels`` channels and the same ``H``/``W`` as
            ``x``, values in ``[-1, 1]``.
        """
        h, w = x.shape[-2:]
        pad_h = -h % self._DOWNSAMPLE
        pad_w = -w % self._DOWNSAMPLE
        if pad_h or pad_w:
            # Stride-2 convs round sizes up while transposed convs double
            # exactly, so non-multiples of 4 would otherwise come back larger
            # than the input. F.pad order is (left, right, top, bottom).
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.encoder(x)
        x = self.decoder(x)
        # Drop the padded border so output size matches the input.
        return x[..., :h, :w]