1.1 下载数据集¶
In [2]:
import torch
torch.cuda.is_available()
Out[2]:
True
In [16]:
from model import *
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import *
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_set = torchvision.datasets.CIFAR10(root='../../../data/',
train=True,
download=False,
transform=transform)
train_loader = DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)
In [39]:
val_set = torchvision.datasets.CIFAR10(root='../../../data/',
train=False,
download=False,
transform=transform)
val_loader = DataLoader(val_set, batch_size=10000, shuffle=True, num_workers=0)
In [54]:
var_data_iter = iter(val_loader)
val_img, val_label = var_data_iter.next()
1.2 看一下数据集¶
In [53]:
import numpy as np
import matplotlib.pyplot as plt
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1,2,0)))
plt.show()
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
test_data_iter = iter(val_loader)
test_img, test_label = test_data_iter.next()
test_img = test_img[:10]; test_label = test_label[:10]
imshow(torchvision.utils.make_grid(test_img))
print(list(map(lambda x:classes[x], test_label.numpy())))
['horse', 'deer', 'bird', 'ship', 'cat', 'frog', 'car', 'truck', 'frog', 'ship']
2.1 开始训练¶
In [58]:
import torch.optim as optim
net = LeNet()
loss_function = nn.CrossEntropyLoss()
# 优化器
optimizer = optim.Adam(net.parameters(), lr = 0.001)
for epoch in range(50):
running_loss = 0.0
for step, data in enumerate(train_loader, start=0):
inputs, labels = data
optimizer.zero_grad()
outputs = net(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if step % 200 == 199:
with torch.no_grad():
outputs = net(val_img)
pridict_y = torch.max(outputs, dim=1)[1]
acc = torch.eq(pridict_y, val_label).sum().item() / val_label.size(0)
print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
(epoch + 1, step + 1, running_loss / 500, acc))
running_loss = 0.0
print("Finished Training!")
save_path = "./LeNet.pth"
torch.save(net.state_dict(), save_path)
[1, 200] train_loss: 0.776 test_accuracy: 0.346 [1, 400] train_loss: 0.654 test_accuracy: 0.432 [1, 600] train_loss: 0.596 test_accuracy: 0.462 [1, 800] train_loss: 0.579 test_accuracy: 0.497 [1, 1000] train_loss: 0.550 test_accuracy: 0.505 [1, 1200] train_loss: 0.526 test_accuracy: 0.527 [2, 200] train_loss: 0.501 test_accuracy: 0.547 [2, 400] train_loss: 0.486 test_accuracy: 0.554 [2, 600] train_loss: 0.478 test_accuracy: 0.580 [2, 800] train_loss: 0.470 test_accuracy: 0.576 [2, 1000] train_loss: 0.459 test_accuracy: 0.590 [2, 1200] train_loss: 0.446 test_accuracy: 0.594 [3, 200] train_loss: 0.427 test_accuracy: 0.610 [3, 400] train_loss: 0.415 test_accuracy: 0.615 [3, 600] train_loss: 0.415 test_accuracy: 0.609 [3, 800] train_loss: 0.411 test_accuracy: 0.623 [3, 1000] train_loss: 0.404 test_accuracy: 0.630 [3, 1200] train_loss: 0.406 test_accuracy: 0.638 [4, 200] train_loss: 0.361 test_accuracy: 0.646 [4, 400] train_loss: 0.365 test_accuracy: 0.646 [4, 600] train_loss: 0.373 test_accuracy: 0.634 [4, 800] train_loss: 0.373 test_accuracy: 0.640 [4, 1000] train_loss: 0.370 test_accuracy: 0.653 [4, 1200] train_loss: 0.359 test_accuracy: 0.660 [5, 200] train_loss: 0.333 test_accuracy: 0.662 [5, 400] train_loss: 0.334 test_accuracy: 0.656 [5, 600] train_loss: 0.338 test_accuracy: 0.653 [5, 800] train_loss: 0.336 test_accuracy: 0.657 [5, 1000] train_loss: 0.348 test_accuracy: 0.660 [5, 1200] train_loss: 0.339 test_accuracy: 0.673 [6, 200] train_loss: 0.300 test_accuracy: 0.657 [6, 400] train_loss: 0.314 test_accuracy: 0.672 [6, 600] train_loss: 0.314 test_accuracy: 0.658 [6, 800] train_loss: 0.314 test_accuracy: 0.677 [6, 1000] train_loss: 0.317 test_accuracy: 0.677 [6, 1200] train_loss: 0.307 test_accuracy: 0.670 [7, 200] train_loss: 0.277 test_accuracy: 0.678 [7, 400] train_loss: 0.282 test_accuracy: 0.675 [7, 600] train_loss: 0.285 test_accuracy: 0.668 [7, 800] train_loss: 0.296 test_accuracy: 0.676 [7, 1000] train_loss: 0.302 test_accuracy: 0.679 [7, 1200] train_loss: 0.299 test_accuracy: 0.684 [8, 200] train_loss: 0.258 test_accuracy: 0.671 [8, 400] train_loss: 0.265 test_accuracy: 0.679 [8, 600] train_loss: 0.265 test_accuracy: 0.675 [8, 800] train_loss: 0.274 test_accuracy: 0.690 [8, 1000] train_loss: 0.268 test_accuracy: 0.683 [8, 1200] train_loss: 0.275 test_accuracy: 0.693 [9, 200] train_loss: 0.241 test_accuracy: 0.693 [9, 400] train_loss: 0.244 test_accuracy: 0.680 [9, 600] train_loss: 0.261 test_accuracy: 0.682 [9, 800] train_loss: 0.245 test_accuracy: 0.679 [9, 1000] train_loss: 0.256 test_accuracy: 0.683 [9, 1200] train_loss: 0.256 test_accuracy: 0.686 [10, 200] train_loss: 0.222 test_accuracy: 0.679 [10, 400] train_loss: 0.233 test_accuracy: 0.688 [10, 600] train_loss: 0.237 test_accuracy: 0.678 [10, 800] train_loss: 0.244 test_accuracy: 0.689 [10, 1000] train_loss: 0.241 test_accuracy: 0.689 [10, 1200] train_loss: 0.241 test_accuracy: 0.689 [11, 200] train_loss: 0.201 test_accuracy: 0.685 [11, 400] train_loss: 0.218 test_accuracy: 0.684 [11, 600] train_loss: 0.223 test_accuracy: 0.692 [11, 800] train_loss: 0.226 test_accuracy: 0.691 [11, 1000] train_loss: 0.224 test_accuracy: 0.682 [11, 1200] train_loss: 0.229 test_accuracy: 0.689 [12, 200] train_loss: 0.185 test_accuracy: 0.687 [12, 400] train_loss: 0.204 test_accuracy: 0.681 [12, 600] train_loss: 0.203 test_accuracy: 0.689 [12, 800] train_loss: 0.217 test_accuracy: 0.687 [12, 1000] train_loss: 0.217 test_accuracy: 0.685 [12, 1200] train_loss: 0.220 test_accuracy: 0.689 [13, 200] train_loss: 0.181 test_accuracy: 0.684 [13, 400] train_loss: 0.191 test_accuracy: 0.683 [13, 600] train_loss: 0.200 test_accuracy: 0.678 [13, 800] train_loss: 0.201 test_accuracy: 0.679 [13, 1000] train_loss: 0.198 test_accuracy: 0.681 [13, 1200] train_loss: 0.201 test_accuracy: 0.686 [14, 200] train_loss: 0.167 test_accuracy: 0.687 [14, 400] train_loss: 0.178 test_accuracy: 0.683 [14, 600] train_loss: 0.178 test_accuracy: 0.681 [14, 800] train_loss: 0.191 test_accuracy: 0.678 [14, 1000] train_loss: 0.187 test_accuracy: 0.682 [14, 1200] train_loss: 0.196 test_accuracy: 0.687 [15, 200] train_loss: 0.156 test_accuracy: 0.688 [15, 400] train_loss: 0.160 test_accuracy: 0.684 [15, 600] train_loss: 0.169 test_accuracy: 0.684 [15, 800] train_loss: 0.185 test_accuracy: 0.678 [15, 1000] train_loss: 0.180 test_accuracy: 0.670 [15, 1200] train_loss: 0.182 test_accuracy: 0.683 [16, 200] train_loss: 0.147 test_accuracy: 0.679 [16, 400] train_loss: 0.159 test_accuracy: 0.683 [16, 600] train_loss: 0.161 test_accuracy: 0.679 [16, 800] train_loss: 0.163 test_accuracy: 0.678 [16, 1000] train_loss: 0.167 test_accuracy: 0.683 [16, 1200] train_loss: 0.186 test_accuracy: 0.678 [17, 200] train_loss: 0.139 test_accuracy: 0.687 [17, 400] train_loss: 0.147 test_accuracy: 0.679 [17, 600] train_loss: 0.157 test_accuracy: 0.679 [17, 800] train_loss: 0.156 test_accuracy: 0.681 [17, 1000] train_loss: 0.161 test_accuracy: 0.675 [17, 1200] train_loss: 0.162 test_accuracy: 0.675 [18, 200] train_loss: 0.123 test_accuracy: 0.678 [18, 400] train_loss: 0.132 test_accuracy: 0.682 [18, 600] train_loss: 0.141 test_accuracy: 0.673 [18, 800] train_loss: 0.142 test_accuracy: 0.667 [18, 1000] train_loss: 0.160 test_accuracy: 0.679 [18, 1200] train_loss: 0.151 test_accuracy: 0.677 [19, 200] train_loss: 0.122 test_accuracy: 0.673 [19, 400] train_loss: 0.128 test_accuracy: 0.680 [19, 600] train_loss: 0.136 test_accuracy: 0.682 [19, 800] train_loss: 0.142 test_accuracy: 0.682 [19, 1000] train_loss: 0.148 test_accuracy: 0.674 [19, 1200] train_loss: 0.146 test_accuracy: 0.679 [20, 200] train_loss: 0.115 test_accuracy: 0.675 [20, 400] train_loss: 0.117 test_accuracy: 0.672 [20, 600] train_loss: 0.128 test_accuracy: 0.670 [20, 800] train_loss: 0.127 test_accuracy: 0.674 [20, 1000] train_loss: 0.136 test_accuracy: 0.671 [20, 1200] train_loss: 0.146 test_accuracy: 0.668 [21, 200] train_loss: 0.105 test_accuracy: 0.673 [21, 400] train_loss: 0.120 test_accuracy: 0.679 [21, 600] train_loss: 0.116 test_accuracy: 0.677 [21, 800] train_loss: 0.127 test_accuracy: 0.677 [21, 1000] train_loss: 0.134 test_accuracy: 0.674 [21, 1200] train_loss: 0.131 test_accuracy: 0.669 [22, 200] train_loss: 0.096 test_accuracy: 0.665 [22, 400] train_loss: 0.106 test_accuracy: 0.675 [22, 600] train_loss: 0.117 test_accuracy: 0.671 [22, 800] train_loss: 0.117 test_accuracy: 0.684 [22, 1000] train_loss: 0.115 test_accuracy: 0.675 [22, 1200] train_loss: 0.130 test_accuracy: 0.677 [23, 200] train_loss: 0.098 test_accuracy: 0.676 [23, 400] train_loss: 0.107 test_accuracy: 0.671 [23, 600] train_loss: 0.107 test_accuracy: 0.673 [23, 800] train_loss: 0.113 test_accuracy: 0.670 [23, 1000] train_loss: 0.116 test_accuracy: 0.670 [23, 1200] train_loss: 0.121 test_accuracy: 0.671 [24, 200] train_loss: 0.082 test_accuracy: 0.672 [24, 400] train_loss: 0.096 test_accuracy: 0.674 [24, 600] train_loss: 0.110 test_accuracy: 0.669 [24, 800] train_loss: 0.115 test_accuracy: 0.675 [24, 1000] train_loss: 0.107 test_accuracy: 0.671 [24, 1200] train_loss: 0.121 test_accuracy: 0.673 [25, 200] train_loss: 0.082 test_accuracy: 0.674 [25, 400] train_loss: 0.092 test_accuracy: 0.673 [25, 600] train_loss: 0.099 test_accuracy: 0.673 [25, 800] train_loss: 0.105 test_accuracy: 0.672 [25, 1000] train_loss: 0.101 test_accuracy: 0.673 [25, 1200] train_loss: 0.108 test_accuracy: 0.674 [26, 200] train_loss: 0.085 test_accuracy: 0.670 [26, 400] train_loss: 0.087 test_accuracy: 0.664 [26, 600] train_loss: 0.100 test_accuracy: 0.658 [26, 800] train_loss: 0.095 test_accuracy: 0.670 [26, 1000] train_loss: 0.101 test_accuracy: 0.667 [26, 1200] train_loss: 0.110 test_accuracy: 0.663 [27, 200] train_loss: 0.076 test_accuracy: 0.671 [27, 400] train_loss: 0.081 test_accuracy: 0.671 [27, 600] train_loss: 0.091 test_accuracy: 0.662
[27, 800] train_loss: 0.093 test_accuracy: 0.663 [27, 1000] train_loss: 0.093 test_accuracy: 0.670 [27, 1200] train_loss: 0.098 test_accuracy: 0.663 [28, 200] train_loss: 0.080 test_accuracy: 0.670 [28, 400] train_loss: 0.076 test_accuracy: 0.664 [28, 600] train_loss: 0.086 test_accuracy: 0.664 [28, 800] train_loss: 0.099 test_accuracy: 0.672 [28, 1000] train_loss: 0.093 test_accuracy: 0.672 [28, 1200] train_loss: 0.098 test_accuracy: 0.660 [29, 200] train_loss: 0.068 test_accuracy: 0.667 [29, 400] train_loss: 0.072 test_accuracy: 0.668 [29, 600] train_loss: 0.093 test_accuracy: 0.672 [29, 800] train_loss: 0.086 test_accuracy: 0.670 [29, 1000] train_loss: 0.097 test_accuracy: 0.665 [29, 1200] train_loss: 0.091 test_accuracy: 0.658 [30, 200] train_loss: 0.067 test_accuracy: 0.675 [30, 400] train_loss: 0.072 test_accuracy: 0.667 [30, 600] train_loss: 0.080 test_accuracy: 0.665 [30, 800] train_loss: 0.087 test_accuracy: 0.664 [30, 1000] train_loss: 0.091 test_accuracy: 0.668 [30, 1200] train_loss: 0.094 test_accuracy: 0.664 [31, 200] train_loss: 0.073 test_accuracy: 0.669 [31, 400] train_loss: 0.065 test_accuracy: 0.665 [31, 600] train_loss: 0.077 test_accuracy: 0.669 [31, 800] train_loss: 0.074 test_accuracy: 0.665 [31, 1000] train_loss: 0.090 test_accuracy: 0.657 [31, 1200] train_loss: 0.094 test_accuracy: 0.654 [32, 200] train_loss: 0.061 test_accuracy: 0.670 [32, 400] train_loss: 0.075 test_accuracy: 0.664 [32, 600] train_loss: 0.073 test_accuracy: 0.666 [32, 800] train_loss: 0.086 test_accuracy: 0.663 [32, 1000] train_loss: 0.097 test_accuracy: 0.659 [32, 1200] train_loss: 0.085 test_accuracy: 0.662 [33, 200] train_loss: 0.054 test_accuracy: 0.668 [33, 400] train_loss: 0.064 test_accuracy: 0.663 [33, 600] train_loss: 0.059 test_accuracy: 0.661 [33, 800] train_loss: 0.076 test_accuracy: 0.672 [33, 1000] train_loss: 0.086 test_accuracy: 0.663 [33, 1200] train_loss: 0.084 test_accuracy: 0.672 [34, 200] train_loss: 0.056 test_accuracy: 0.663 [34, 400] train_loss: 0.065 test_accuracy: 0.666 [34, 600] train_loss: 0.070 test_accuracy: 0.668 [34, 800] train_loss: 0.082 test_accuracy: 0.663 [34, 1000] train_loss: 0.083 test_accuracy: 0.664 [34, 1200] train_loss: 0.082 test_accuracy: 0.668 [35, 200] train_loss: 0.053 test_accuracy: 0.668 [35, 400] train_loss: 0.053 test_accuracy: 0.664 [35, 600] train_loss: 0.065 test_accuracy: 0.668 [35, 800] train_loss: 0.087 test_accuracy: 0.662 [35, 1000] train_loss: 0.084 test_accuracy: 0.664 [35, 1200] train_loss: 0.085 test_accuracy: 0.660 [36, 200] train_loss: 0.058 test_accuracy: 0.666 [36, 400] train_loss: 0.062 test_accuracy: 0.664 [36, 600] train_loss: 0.062 test_accuracy: 0.669 [36, 800] train_loss: 0.066 test_accuracy: 0.666 [36, 1000] train_loss: 0.076 test_accuracy: 0.656 [36, 1200] train_loss: 0.081 test_accuracy: 0.661 [37, 200] train_loss: 0.052 test_accuracy: 0.668 [37, 400] train_loss: 0.054 test_accuracy: 0.663 [37, 600] train_loss: 0.073 test_accuracy: 0.664 [37, 800] train_loss: 0.071 test_accuracy: 0.660 [37, 1000] train_loss: 0.085 test_accuracy: 0.665 [37, 1200] train_loss: 0.072 test_accuracy: 0.663 [38, 200] train_loss: 0.051 test_accuracy: 0.662 [38, 400] train_loss: 0.065 test_accuracy: 0.657 [38, 600] train_loss: 0.068 test_accuracy: 0.662 [38, 800] train_loss: 0.076 test_accuracy: 0.651 [38, 1000] train_loss: 0.078 test_accuracy: 0.661 [38, 1200] train_loss: 0.073 test_accuracy: 0.661 [39, 200] train_loss: 0.047 test_accuracy: 0.665 [39, 400] train_loss: 0.045 test_accuracy: 0.665 [39, 600] train_loss: 0.056 test_accuracy: 0.657 [39, 800] train_loss: 0.071 test_accuracy: 0.660 [39, 1000] train_loss: 0.075 test_accuracy: 0.664 [39, 1200] train_loss: 0.081 test_accuracy: 0.657 [40, 200] train_loss: 0.050 test_accuracy: 0.668 [40, 400] train_loss: 0.045 test_accuracy: 0.665 [40, 600] train_loss: 0.057 test_accuracy: 0.666 [40, 800] train_loss: 0.066 test_accuracy: 0.655 [40, 1000] train_loss: 0.070 test_accuracy: 0.665 [40, 1200] train_loss: 0.079 test_accuracy: 0.658 [41, 200] train_loss: 0.058 test_accuracy: 0.656 [41, 400] train_loss: 0.054 test_accuracy: 0.658 [41, 600] train_loss: 0.054 test_accuracy: 0.664 [41, 800] train_loss: 0.063 test_accuracy: 0.660 [41, 1000] train_loss: 0.060 test_accuracy: 0.659 [41, 1200] train_loss: 0.075 test_accuracy: 0.650 [42, 200] train_loss: 0.051 test_accuracy: 0.663 [42, 400] train_loss: 0.062 test_accuracy: 0.665 [42, 600] train_loss: 0.050 test_accuracy: 0.656 [42, 800] train_loss: 0.065 test_accuracy: 0.660 [42, 1000] train_loss: 0.058 test_accuracy: 0.658 [42, 1200] train_loss: 0.072 test_accuracy: 0.659 [43, 200] train_loss: 0.052 test_accuracy: 0.666 [43, 400] train_loss: 0.044 test_accuracy: 0.667 [43, 600] train_loss: 0.060 test_accuracy: 0.660 [43, 800] train_loss: 0.063 test_accuracy: 0.661 [43, 1000] train_loss: 0.061 test_accuracy: 0.661 [43, 1200] train_loss: 0.069 test_accuracy: 0.659 [44, 200] train_loss: 0.041 test_accuracy: 0.665 [44, 400] train_loss: 0.052 test_accuracy: 0.657 [44, 600] train_loss: 0.050 test_accuracy: 0.666 [44, 800] train_loss: 0.047 test_accuracy: 0.665 [44, 1000] train_loss: 0.080 test_accuracy: 0.652 [44, 1200] train_loss: 0.075 test_accuracy: 0.652 [45, 200] train_loss: 0.046 test_accuracy: 0.660 [45, 400] train_loss: 0.051 test_accuracy: 0.665 [45, 600] train_loss: 0.055 test_accuracy: 0.666 [45, 800] train_loss: 0.055 test_accuracy: 0.664 [45, 1000] train_loss: 0.065 test_accuracy: 0.657 [45, 1200] train_loss: 0.068 test_accuracy: 0.661 [46, 200] train_loss: 0.044 test_accuracy: 0.655 [46, 400] train_loss: 0.043 test_accuracy: 0.668 [46, 600] train_loss: 0.047 test_accuracy: 0.663 [46, 800] train_loss: 0.045 test_accuracy: 0.664 [46, 1000] train_loss: 0.069 test_accuracy: 0.653 [46, 1200] train_loss: 0.074 test_accuracy: 0.653 [47, 200] train_loss: 0.049 test_accuracy: 0.656 [47, 400] train_loss: 0.045 test_accuracy: 0.658 [47, 600] train_loss: 0.058 test_accuracy: 0.655 [47, 800] train_loss: 0.057 test_accuracy: 0.657 [47, 1000] train_loss: 0.056 test_accuracy: 0.663 [47, 1200] train_loss: 0.058 test_accuracy: 0.662 [48, 200] train_loss: 0.043 test_accuracy: 0.661 [48, 400] train_loss: 0.053 test_accuracy: 0.655 [48, 600] train_loss: 0.062 test_accuracy: 0.654 [48, 800] train_loss: 0.063 test_accuracy: 0.655 [48, 1000] train_loss: 0.062 test_accuracy: 0.661 [48, 1200] train_loss: 0.062 test_accuracy: 0.661 [49, 200] train_loss: 0.049 test_accuracy: 0.659 [49, 400] train_loss: 0.046 test_accuracy: 0.663 [49, 600] train_loss: 0.045 test_accuracy: 0.657 [49, 800] train_loss: 0.043 test_accuracy: 0.666 [49, 1000] train_loss: 0.057 test_accuracy: 0.655 [49, 1200] train_loss: 0.070 test_accuracy: 0.660 [50, 200] train_loss: 0.035 test_accuracy: 0.656 [50, 400] train_loss: 0.038 test_accuracy: 0.661 [50, 600] train_loss: 0.043 test_accuracy: 0.663 [50, 800] train_loss: 0.047 test_accuracy: 0.659 [50, 1000] train_loss: 0.064 test_accuracy: 0.661 [50, 1200] train_loss: 0.071 test_accuracy: 0.651 Finished Training!
3.1 测试¶
In [1]:
import torch
from torchvision.transforms import transforms
from PIL import Image
from model import LeNet
transform = transforms.Compose([transforms.Resize((32,32)),
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LeNet()
net.load_state_dict(torch.load('LeNet.pth'))
img = Image.open(r'C:/Users/lds/Desktop/test.webp')
# [C, H, W]
img = transform(img)
# [B, C, H, W]
img = torch.unsqueeze(img, dim=0)
with torch.no_grad():
outputs = net(img)
predict = torch.max(outputs, dim=1)[1].data.numpy()
title = classes[int(predict)]
import matplotlib.pyplot as plt
import numpy as np
# 去除批次维度
image = np.squeeze(img, 0)
# unnormalize
plt.imshow(np.transpose(image, (1,2,0)) / 2 + 0.5)
# IMG show
plt.yticks([])
plt.xticks([])
plt.xlabel(title)
plt.show()