import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import os

dir_path = os.path.dirname(os.path.realpath(__file__))

from tensorboardX import SummaryWriter  # used below for training/test logging
from tqdm import tqdm
import datetime
from torch.utils.data import DataLoader, TensorDataset

date = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
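
# BoardEvaluationNet: two 3x3 convolutions followed by two fully connected
# layers; maps a board_size x board_size board to a per-cell score map of the
# same shape.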
class BoardEvaluationNet(nn.Module):
    def __init__(self, board_size):
        super(BoardEvaluationNet, self).__init__()
        self.board_size = board_size
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * board_size * board_size, 256)
        self.fc2 = nn.Linear(256, board_size * board_size)

    def forward(self, x):
        x = x.unsqueeze(1)  # Add a channel dimension
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, 32 * self.board_size * self.board_size)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.view(-1, self.board_size, self.board_size)


def normalize(t):
    # Currently a no-op: scores are used as-is (placeholder for score normalization).
    return t
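
# Training script: loads (state, scores) pairs from train_data.pkl, fine-tunes
# net_for_black with Adam starting from a saved checkpoint, logs losses to
# TensorBoard, and checkpoints the weights whenever the test loss improves.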
if __name__ == "__main__":
    writer = SummaryWriter(os.path.join(dir_path, 'train_data/log', date), comment='BoardEvaluationNet')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best = np.inf
    loss_fn = nn.CrossEntropyLoss()

    # Example usage
    BS = 15
    net_for_black = BoardEvaluationNet(BS).to(device)
    net_for_white = BoardEvaluationNet(BS).to(device)
    net_for_black.load_state_dict(torch.load(os.path.join(dir_path, 'train_data/model', 'best_loss=680.5813717259707.pth')))
    optimizer = torch.optim.Adam(net_for_black.parameters(), lr=1e-5, betas=(0.9, 0.99),
                                 eps=1e-8)

    data_path = os.path.join(dir_path, 'train_data/data', 'train_data.pkl')
    with open(data_path, 'rb') as f:
        datas = pickle.load(f)
    train_data_for_black = datas[1][:int(len(datas[1]) * 1)]
    test_data_for_black = datas[1][int(len(datas[1]) * 0.8):]
    train_data_for_white = datas[-1]

    epochs = 500
    batch_size = 32
    train_dataset = TensorDataset(
        torch.stack([torch.tensor(item['state'], dtype=torch.float) for item in train_data_for_black]),
        torch.stack([normalize(torch.tensor(item['scores'], dtype=torch.float)) for item in train_data_for_black]))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
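
    # Training loop: one optimizer step per mini-batch; per-batch inference
    # time and the mean epoch loss are logged to TensorBoard.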
    for epoch in range(epochs):
        epoch_loss = 0
        print('Epoch:', epoch)
        for i, (states, scores) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            states = states.to(device)
            scores = scores.to(device)
            infer_start = datetime.datetime.now()
            output_tensor = net_for_black(states)
            infer_end = datetime.datetime.now()
            loss = loss_fn(output_tensor, scores)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()
            writer.add_scalar('train/infer_time', (infer_end - infer_start).microseconds,
                              i + epoch * len(train_dataloader))
        epoch_loss /= len(train_dataloader)
        writer.add_scalar('train/epoch_loss', epoch_loss, epoch)
        # Evaluate on test_data_for_black after each epoch and checkpoint on improvement.
        with torch.no_grad():
            test_loss = 0
            net_for_black.eval()
            for j, item in tqdm(enumerate(test_data_for_black), total=len(test_data_for_black)):
                scores = normalize(torch.tensor(item['scores'], dtype=torch.float).to(device).unsqueeze(0))  # cast to float
                state = item['state']
                input_tensor = torch.tensor(state, dtype=torch.float).to(device).unsqueeze(0)  # cast to float and move to the device
                output_tensor = net_for_black(input_tensor).to(device)
                loss = loss_fn(output_tensor, scores)
                test_loss += loss.item()
            test_loss /= len(test_data_for_black)
            writer.add_scalar('test/loss', test_loss, epoch)
            if best > test_loss:
                best = test_loss
                model_path = os.path.join(dir_path, 'train_data/model')
                if not os.path.exists(model_path):
                    os.makedirs(model_path)
                torch.save(net_for_black.state_dict(),
                           os.path.join(model_path, f'best_loss={best}.pth'))
        net_for_black.train()