NewBreaker
commited on
Commit
•
b83d9ec
1
Parent(s):
47d7bda
auto git
Browse files- tools/ResNet_MNIST.py +206 -0
- tools/ResNet_cat_vs_dog_Ram.py +15 -0
- tools/data_test.py +80 -0
- tools/data_train.py +47 -0
- tools/将数据集按照比例进行拆分.py +84 -0
tools/ResNet_MNIST.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# In[1] 导入所需工具包
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision
|
5 |
+
from torchvision import datasets, transforms
|
6 |
+
import time
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from math import floor, ceil
|
9 |
+
from torch.utils.data import DataLoader,TensorDataset
|
10 |
+
# import torchvision.transforms as transforms
|
11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
print(device)
|
13 |
+
# In[1] 设置超参数
|
14 |
+
num_epochs = 60
|
15 |
+
batch_size = 1000
|
16 |
+
learning_rate = 0.001
|
17 |
+
|
18 |
+
# In[2] 获取数据包括训练数据和测试数据
|
19 |
+
|
20 |
+
transform = torchvision.transforms.Compose([
|
21 |
+
torchvision.transforms.ToTensor(),
|
22 |
+
torchvision.transforms.Normalize(
|
23 |
+
(0.1307,), (0.3081,))
|
24 |
+
])
|
25 |
+
|
26 |
+
|
27 |
+
train_set = torchvision.datasets.MNIST(root='MNIST', train=True, download=True)
|
28 |
+
train_data = train_set.data.float().unsqueeze(1) / 255.0
|
29 |
+
train_labels = train_set.targets
|
30 |
+
train_dataset = TensorDataset(train_data,train_labels)
|
31 |
+
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
|
32 |
+
|
33 |
+
|
34 |
+
test_set = torchvision.datasets.MNIST(root='MNIST', train=False, download=True)
|
35 |
+
test_data = test_set.data.float().unsqueeze(1) / 255.0
|
36 |
+
test_labels = test_set.targets
|
37 |
+
test_dataset = TensorDataset(test_data,test_labels)
|
38 |
+
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=True)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
# In[1] 定义卷积核
|
43 |
+
def conv3x3(in_channels, out_channels, stride=1):
|
44 |
+
return nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
45 |
+
stride=stride, padding=1, bias=True)
|
46 |
+
|
47 |
+
|
48 |
+
# In[1] 定义残差块
|
49 |
+
class ResidualBlock(nn.Module):
|
50 |
+
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
|
51 |
+
super(ResidualBlock, self).__init__()
|
52 |
+
self.conv1 = conv3x3(in_channels, out_channels, stride)
|
53 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
54 |
+
self.relu = nn.ReLU(inplace=True)
|
55 |
+
self.conv2 = conv3x3(out_channels, out_channels)
|
56 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
57 |
+
self.downsample = downsample
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
residual = x
|
61 |
+
out = self.conv1(x)
|
62 |
+
out = self.bn1(out)
|
63 |
+
out = self.relu(out)
|
64 |
+
out = self.conv2(out)
|
65 |
+
out = self.bn2(out)
|
66 |
+
# 下采样
|
67 |
+
if self.downsample:
|
68 |
+
residual = self.downsample(x)
|
69 |
+
out += residual
|
70 |
+
out = self.relu(out)
|
71 |
+
return out
|
72 |
+
|
73 |
+
|
74 |
+
# In[1] 搭建残差神经网络
|
75 |
+
class ResNet(nn.Module):
|
76 |
+
def __init__(self, block, layers, num_classes=10):
|
77 |
+
super(ResNet, self).__init__()
|
78 |
+
self.in_channels = 16
|
79 |
+
self.conv = conv3x3(1, 16)
|
80 |
+
self.bn = nn.BatchNorm2d(16)
|
81 |
+
self.relu = nn.ReLU(inplace=True)
|
82 |
+
# 构建残差块,恒等映射
|
83 |
+
# in_channels == out_channels and stride = 1 所以这里我们构建残差块,没有下采样
|
84 |
+
self.layer1 = self.make_layer(block, 16, layers[0], stride=1)
|
85 |
+
# 不构建残差块,进行了下采样
|
86 |
+
# layers中记录的是数字,表示对应位置的残差块数目
|
87 |
+
self.layer2 = self.make_layer(block, 32, layers[1], 2)
|
88 |
+
# 不构建残差块,进行了下采样
|
89 |
+
self.layer3 = self.make_layer(block, 64, layers[2], 2)
|
90 |
+
self.avg_pool = nn.AvgPool2d(8)
|
91 |
+
self.fc1 = nn.Linear(3136, 128)
|
92 |
+
self.normfc12 = nn.LayerNorm((128), eps=1e-5)
|
93 |
+
self.fc2 = nn.Linear(128, num_classes)
|
94 |
+
|
95 |
+
def make_layer(self, block, out_channels, blocks, stride=1):
|
96 |
+
downsample = None
|
97 |
+
if (stride != 1) or (self.in_channels != out_channels):
|
98 |
+
downsample = nn.Sequential(
|
99 |
+
conv3x3(self.in_channels, out_channels, stride=stride),
|
100 |
+
nn.BatchNorm2d(out_channels))
|
101 |
+
layers = []
|
102 |
+
layers.append(block(self.in_channels, out_channels, stride, downsample))
|
103 |
+
# 当out_channels = 32时,in_channels也变成32了
|
104 |
+
self.in_channels = out_channels
|
105 |
+
# blocks是残差块的数目
|
106 |
+
# 残差块之后的网络结构,是out_channels->out_channels的
|
107 |
+
# 可以说,make_layer做的是输出尺寸相同的所有网络结构
|
108 |
+
# 由于输出尺寸会改变,我们用make_layers去生成一大块对应尺寸完整网络结构
|
109 |
+
for i in range(1, blocks):
|
110 |
+
layers.append(block(out_channels, out_channels))
|
111 |
+
return nn.Sequential(*layers)
|
112 |
+
|
113 |
+
def forward(self, x):
|
114 |
+
out = self.conv(x)
|
115 |
+
out = self.bn(out)
|
116 |
+
out = self.relu(out)
|
117 |
+
# layer1是三块in_channels等于16的网络结构,包括三个恒等映射
|
118 |
+
out = self.layer1(out)
|
119 |
+
# layer2包括了16->32下采样,然后是32的三个恒等映射
|
120 |
+
out = self.layer2(out)
|
121 |
+
# layer3包括了32->64的下采样,然后是64的三个恒等映射
|
122 |
+
out = self.layer3(out)
|
123 |
+
# out = self.avg_pool(out)
|
124 |
+
# 全连接压缩
|
125 |
+
# out.size(0)可以看作是batch_size
|
126 |
+
out = out.view(out.size(0), -1)
|
127 |
+
out = self.fc1(out)
|
128 |
+
out = self.normfc12(out)
|
129 |
+
out = self.relu(out)
|
130 |
+
out = self.fc2(out)
|
131 |
+
return out
|
132 |
+
|
133 |
+
|
134 |
+
# In[1] 定义模型和损失函数
|
135 |
+
# [2,2,2]表示的是不同in_channels下的恒等映射数目
|
136 |
+
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)
|
137 |
+
criterion = nn.CrossEntropyLoss()
|
138 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
139 |
+
|
140 |
+
|
141 |
+
# In[1] 设置一个通过优化器更新学习率的函数
|
142 |
+
def update_lr(optimizer, lr):
|
143 |
+
for param_group in optimizer.param_groups:
|
144 |
+
param_group['lr'] = lr
|
145 |
+
|
146 |
+
|
147 |
+
# In[1] 定义测试函数
|
148 |
+
def test(model, test_loader):
|
149 |
+
model.eval()
|
150 |
+
with torch.no_grad():
|
151 |
+
correct = 0
|
152 |
+
total = 0
|
153 |
+
for images, labels in test_loader:
|
154 |
+
images = images.to(device)
|
155 |
+
labels = labels.to(device)
|
156 |
+
outputs = model(images)
|
157 |
+
_, predicted = torch.max(outputs.data, 1)
|
158 |
+
total += labels.size(0)
|
159 |
+
correct += (predicted == labels).sum().item()
|
160 |
+
|
161 |
+
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
|
162 |
+
|
163 |
+
|
164 |
+
# In[1] 训练模型更新学习率
|
165 |
+
total_step = len(train_loader)
|
166 |
+
curr_lr = learning_rate
|
167 |
+
for epoch in range(num_epochs):
|
168 |
+
in_epoch = time.time()
|
169 |
+
for i, (images, labels) in enumerate(train_loader):
|
170 |
+
images = images.to(device)
|
171 |
+
labels = labels.to(device)
|
172 |
+
|
173 |
+
# Forward pass
|
174 |
+
outputs = model(images)
|
175 |
+
loss = criterion(outputs, labels)
|
176 |
+
|
177 |
+
# Backward and optimize
|
178 |
+
optimizer.zero_grad()
|
179 |
+
loss.backward()
|
180 |
+
optimizer.step()
|
181 |
+
|
182 |
+
if (i + 1) % 100 == 0:
|
183 |
+
print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
|
184 |
+
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
|
185 |
+
test(model, test_loader)
|
186 |
+
out_epoch = time.time()
|
187 |
+
print(f"use {(out_epoch - in_epoch) // 60}min{(out_epoch - in_epoch) % 60}s")
|
188 |
+
if (epoch + 1) % 20 == 0:
|
189 |
+
curr_lr /= 3
|
190 |
+
update_lr(optimizer, curr_lr)
|
191 |
+
# In[1] 测试模型并保存
|
192 |
+
model.eval()
|
193 |
+
with torch.no_grad():
|
194 |
+
correct = 0
|
195 |
+
total = 0
|
196 |
+
for images, labels in test_loader:
|
197 |
+
images = images.to(device)
|
198 |
+
labels = labels.to(device)
|
199 |
+
outputs = model(images)
|
200 |
+
_, predicted = torch.max(outputs.data, 1)
|
201 |
+
total += labels.size(0)
|
202 |
+
correct += (predicted == labels).sum().item()
|
203 |
+
|
204 |
+
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
|
205 |
+
|
206 |
+
torch.save(model.state_dict(), '../resnet.ckpt')
|
tools/ResNet_cat_vs_dog_Ram.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import datasets, transforms
|
3 |
+
from torch.utils.data import DataLoader,TensorDataset
|
4 |
+
|
5 |
+
transform = transforms.Compose([
|
6 |
+
transforms.Resize((512, 512)),
|
7 |
+
transforms.ToTensor(),
|
8 |
+
])
|
9 |
+
|
10 |
+
# 加载训练集和测试集
|
11 |
+
train_set = datasets.ImageFolder(root='data/cat_vs_dog/train', transform=transform)
|
12 |
+
test_set = datasets.ImageFolder(root='data/cat_vs_dog/test', transform=transform)
|
13 |
+
|
14 |
+
train_data = train_set.imgs
|
15 |
+
print("train_data:", train_data)
|
tools/data_test.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
import cv2
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
import torchvision
|
10 |
+
from torchvision import models,transforms,datasets
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch import optim
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
import os
|
20 |
+
import shutil
|
21 |
+
import random
|
22 |
+
def make_dir(path):
|
23 |
+
import os
|
24 |
+
dir = os.path.exists(path)
|
25 |
+
if not dir:
|
26 |
+
os.makedirs(path)
|
27 |
+
|
28 |
+
|
29 |
+
def get_filename_and_houzhui(full_path):
|
30 |
+
import os
|
31 |
+
path, file_full_name = os.path.split(full_path)
|
32 |
+
file_name, 后缀名 = os.path.splitext(file_full_name)
|
33 |
+
return path,file_name,后缀名
|
34 |
+
|
35 |
+
|
36 |
+
dataset_root_path = '../data/cat_vs_dog'
|
37 |
+
train_path_cat_new = os.path.join(dataset_root_path, 'new/train/cat')
|
38 |
+
train_path_dog_new = os.path.join(dataset_root_path, 'new/train/dog')
|
39 |
+
|
40 |
+
test_path_cat_new = os.path.join(dataset_root_path, 'new/test/cat')
|
41 |
+
test_path_dog_new = os.path.join(dataset_root_path, 'new/test/dog')
|
42 |
+
|
43 |
+
make_dir(train_path_cat_new)
|
44 |
+
make_dir(train_path_dog_new)
|
45 |
+
make_dir(test_path_cat_new)
|
46 |
+
make_dir(test_path_dog_new)
|
47 |
+
|
48 |
+
image_dir_path = os.path.join(dataset_root_path,'train')
|
49 |
+
image_name_list = os.listdir(image_dir_path)
|
50 |
+
for image_name in tqdm(image_name_list):
|
51 |
+
image_path = os.path.join(image_dir_path,image_name)
|
52 |
+
path, file_name, 后缀名 = get_filename_and_houzhui(full_path=image_path)
|
53 |
+
# print("file_name:", file_name)
|
54 |
+
# 定义随机数的范围和对应的概率
|
55 |
+
nums = [1, 2]
|
56 |
+
probs = [0.9, 0.1] #设定训练集和测试集的比率
|
57 |
+
|
58 |
+
random_nums = random.choices(nums, weights=probs)[0]
|
59 |
+
|
60 |
+
if(random_nums == 1): #摇筛子如果摇到了1,那么就是训练集
|
61 |
+
|
62 |
+
if('cat' in file_name):
|
63 |
+
shutil.copy(image_path, train_path_cat_new)
|
64 |
+
elif('dog' in file_name):
|
65 |
+
shutil.copy(image_path, train_path_dog_new)
|
66 |
+
elif(random_nums == 2): #摇骰子如果摇到了2,那么就是测试集
|
67 |
+
if('cat' in file_name):
|
68 |
+
shutil.copy(image_path, test_path_cat_new)
|
69 |
+
elif('dog' in file_name):
|
70 |
+
shutil.copy(image_path, test_path_dog_new)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
tools/data_train.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
import shutil
|
6 |
+
|
7 |
+
|
8 |
+
def make_dir(path):
|
9 |
+
import os
|
10 |
+
dir = os.path.exists(path)
|
11 |
+
if not dir:
|
12 |
+
os.makedirs(path)
|
13 |
+
|
14 |
+
|
15 |
+
def get_filename_and_houzhui(full_path):
|
16 |
+
import os
|
17 |
+
path, file_full_name = os.path.split(full_path)
|
18 |
+
file_name, 后缀名 = os.path.splitext(file_full_name)
|
19 |
+
return path,file_name,后缀名
|
20 |
+
|
21 |
+
|
22 |
+
dataset_root_path = '../data/cat_vs_dog'
|
23 |
+
train_path_cat_new = os.path.join(dataset_root_path, 'new/train/cat')
|
24 |
+
train_path_dog_new = os.path.join(dataset_root_path, 'new/train/dog')
|
25 |
+
make_dir(train_path_cat_new)
|
26 |
+
make_dir(train_path_dog_new)
|
27 |
+
|
28 |
+
image_dir_path = os.path.join(dataset_root_path,'train')
|
29 |
+
image_name_list = os.listdir(image_dir_path)
|
30 |
+
for image_name in image_name_list:
|
31 |
+
image_path = os.path.join(image_dir_path,image_name)
|
32 |
+
path, file_name, 后缀名 = get_filename_and_houzhui(full_path=image_path)
|
33 |
+
print("file_name:", file_name)
|
34 |
+
|
35 |
+
if('cat' in file_name):
|
36 |
+
shutil.copy(image_path,train_path_cat_new)
|
37 |
+
elif('dog' in file_name):
|
38 |
+
shutil.copy(image_path, train_path_dog_new)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
tools/将数据集按照比例进行拆分.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from tqdm import tqdm
|
7 |
+
import cv2
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
import torchvision
|
11 |
+
from torchvision import models,transforms,datasets
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch import optim
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
import os
|
21 |
+
import shutil
|
22 |
+
import random
|
23 |
+
def make_dir(path):
|
24 |
+
import os
|
25 |
+
dir = os.path.exists(path)
|
26 |
+
if not dir:
|
27 |
+
os.makedirs(path)
|
28 |
+
|
29 |
+
|
30 |
+
def get_filename_and_houzhui(full_path):
|
31 |
+
import os
|
32 |
+
path, file_full_name = os.path.split(full_path)
|
33 |
+
file_name, 后缀名 = os.path.splitext(file_full_name)
|
34 |
+
return path,file_name,后缀名
|
35 |
+
|
36 |
+
|
37 |
+
dataset_root_path = '../data/cat_vs_dog'
|
38 |
+
train_path_cat_new = os.path.join(dataset_root_path, 'new/train/cat')
|
39 |
+
train_path_dog_new = os.path.join(dataset_root_path, 'new/train/dog')
|
40 |
+
|
41 |
+
test_path_cat_new = os.path.join(dataset_root_path, 'new/test/cat')
|
42 |
+
test_path_dog_new = os.path.join(dataset_root_path, 'new/test/dog')
|
43 |
+
|
44 |
+
make_dir(train_path_cat_new)
|
45 |
+
make_dir(train_path_dog_new)
|
46 |
+
make_dir(test_path_cat_new)
|
47 |
+
make_dir(test_path_dog_new)
|
48 |
+
|
49 |
+
image_dir_path = os.path.join(dataset_root_path,'train')
|
50 |
+
image_name_list = os.listdir(image_dir_path)
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
for image_name in tqdm(image_name_list):
|
55 |
+
image_path = os.path.join(image_dir_path,image_name)
|
56 |
+
path, file_name, 后缀名 = get_filename_and_houzhui(full_path=image_path)
|
57 |
+
# print("file_name:", file_name)
|
58 |
+
# 定义随机数的范围和对应的概率
|
59 |
+
nums = [1, 2]
|
60 |
+
probs = [0.9, 0.1] #设定训练集和测试集的比率
|
61 |
+
|
62 |
+
random_nums = random.choices(nums, weights=probs)[0]
|
63 |
+
|
64 |
+
if(random_nums == 1): #摇筛子如果摇到了1,那么就是训练集
|
65 |
+
|
66 |
+
if('cat' in file_name):
|
67 |
+
shutil.copy(image_path, train_path_cat_new)
|
68 |
+
elif('dog' in file_name):
|
69 |
+
shutil.copy(image_path, train_path_dog_new)
|
70 |
+
elif(random_nums == 2): #摇骰子如果摇到了2,那么就是测试集
|
71 |
+
if('cat' in file_name):
|
72 |
+
shutil.copy(image_path, test_path_cat_new)
|
73 |
+
elif('dog' in file_name):
|
74 |
+
shutil.copy(image_path, test_path_dog_new)
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|