|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import DataLoader, Dataset |
|
from torchvision import datasets, transforms |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
|
|
from model import ColorNet |
|
|
|
|
|
# Training configuration. These are named constants so the train and test
# loaders cannot drift apart on batch size, and so hyperparameters can be
# changed in one place.
DATA_ROOT = './data'
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
NUM_EPOCHS = 10

# Only preprocessing step: convert dataset images to tensors
# (torchvision's ToTensor also scales pixel values to [0, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
])

# CIFAR-10 train and test splits; download=True fetches them into DATA_ROOT
# if they are not already present.
train_dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=True, transform=transform, download=True
)
test_dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=False, transform=transform, download=True
)

# Shuffle only the training split; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): test_loader is constructed but never used in this script —
# presumably meant for an evaluation step; confirm or remove.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model = ColorNet()
# Mean-squared-error regression loss.
# NOTE(review): CIFAR10 yields (image, label) pairs; how train_model derives
# the inputs/targets compared by this loss is not visible here — confirm in
# model.py.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): `model` is passed explicitly even though train_model is
# looked up on the instance. This only works if train_model is a
# @staticmethod or has a (self, model, ...) signature — verify against
# ColorNet.
model.train_model(model, train_loader, criterion, optimizer, num_epochs=NUM_EPOCHS)
|
|