"""Train a ``ColorNet`` model on the CIFAR-10 training split.

Running this script builds CIFAR-10 train/test datasets and data loaders,
instantiates ``ColorNet`` with an MSE loss and an Adam optimizer, and then
hands everything to ``train_model`` for 10 epochs.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

from model import ColorNet

# The only preprocessing step is converting each sample to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Both splits share the same root directory and transform; download=True
# lets torchvision fetch the archive into ./data when it is missing.
train_dataset = datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, transform=transform, download=True
)

# Only the training loader is shuffled; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# NOTE(review): test_loader is not used anywhere in the visible code.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

model = ColorNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): `model` is passed explicitly even though train_model is
# looked up on the instance. This only works if ColorNet.train_model is a
# staticmethod (or otherwise does not bind `self`) -- confirm in model.py.
model.train_model(model, train_loader, criterion, optimizer, num_epochs=10)