1.1 下载数据集¶

In [2]:
import torch
torch.cuda.is_available()
Out[2]:
True
In [16]:
from model import *
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import *


transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_set = torchvision.datasets.CIFAR10(root='../../../data/',
                                        train=True,
                                        download=False,
                                        transform=transform)

train_loader = DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)
In [39]:
val_set = torchvision.datasets.CIFAR10(root='../../../data/',
                                        train=False,
                                        download=False,
                                        transform=transform)

val_loader = DataLoader(val_set, batch_size=10000, shuffle=True, num_workers=0)
In [54]:
var_data_iter = iter(val_loader)

val_img, val_label = var_data_iter.next()

1.2 看一下数据集¶

In [53]:
import numpy as np
import matplotlib.pyplot as plt

def imshow(img):
    img =  img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)))
    plt.show()
    
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
test_data_iter = iter(val_loader)

test_img, test_label = test_data_iter.next()
test_img = test_img[:10]; test_label = test_label[:10]
imshow(torchvision.utils.make_grid(test_img))

print(list(map(lambda x:classes[x], test_label.numpy())))
No description has been provided for this image
['horse', 'deer', 'bird', 'ship', 'cat', 'frog', 'car', 'truck', 'frog', 'ship']

2.1 开始训练¶

In [58]:
import torch.optim as optim 

net = LeNet()
loss_function = nn.CrossEntropyLoss()

# 优化器
optimizer = optim.Adam(net.parameters(), lr = 0.001)

for epoch in range(50):
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        inputs, labels = data
        
        optimizer.zero_grad()
        
        outputs = net(inputs)
        
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if step % 200 == 199:
            with torch.no_grad():
                outputs = net(val_img)
                pridict_y = torch.max(outputs, dim=1)[1]
                acc = torch.eq(pridict_y, val_label).sum().item() / val_label.size(0)
                
                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, acc))
                running_loss = 0.0

print("Finished Training!")
save_path = "./LeNet.pth"
torch.save(net.state_dict(), save_path)
[1,   200] train_loss: 0.776  test_accuracy: 0.346
[1,   400] train_loss: 0.654  test_accuracy: 0.432
[1,   600] train_loss: 0.596  test_accuracy: 0.462
[1,   800] train_loss: 0.579  test_accuracy: 0.497
[1,  1000] train_loss: 0.550  test_accuracy: 0.505
[1,  1200] train_loss: 0.526  test_accuracy: 0.527
[2,   200] train_loss: 0.501  test_accuracy: 0.547
[2,   400] train_loss: 0.486  test_accuracy: 0.554
[2,   600] train_loss: 0.478  test_accuracy: 0.580
[2,   800] train_loss: 0.470  test_accuracy: 0.576
[2,  1000] train_loss: 0.459  test_accuracy: 0.590
[2,  1200] train_loss: 0.446  test_accuracy: 0.594
[3,   200] train_loss: 0.427  test_accuracy: 0.610
[3,   400] train_loss: 0.415  test_accuracy: 0.615
[3,   600] train_loss: 0.415  test_accuracy: 0.609
[3,   800] train_loss: 0.411  test_accuracy: 0.623
[3,  1000] train_loss: 0.404  test_accuracy: 0.630
[3,  1200] train_loss: 0.406  test_accuracy: 0.638
[4,   200] train_loss: 0.361  test_accuracy: 0.646
[4,   400] train_loss: 0.365  test_accuracy: 0.646
[4,   600] train_loss: 0.373  test_accuracy: 0.634
[4,   800] train_loss: 0.373  test_accuracy: 0.640
[4,  1000] train_loss: 0.370  test_accuracy: 0.653
[4,  1200] train_loss: 0.359  test_accuracy: 0.660
[5,   200] train_loss: 0.333  test_accuracy: 0.662
[5,   400] train_loss: 0.334  test_accuracy: 0.656
[5,   600] train_loss: 0.338  test_accuracy: 0.653
[5,   800] train_loss: 0.336  test_accuracy: 0.657
[5,  1000] train_loss: 0.348  test_accuracy: 0.660
[5,  1200] train_loss: 0.339  test_accuracy: 0.673
[6,   200] train_loss: 0.300  test_accuracy: 0.657
[6,   400] train_loss: 0.314  test_accuracy: 0.672
[6,   600] train_loss: 0.314  test_accuracy: 0.658
[6,   800] train_loss: 0.314  test_accuracy: 0.677
[6,  1000] train_loss: 0.317  test_accuracy: 0.677
[6,  1200] train_loss: 0.307  test_accuracy: 0.670
[7,   200] train_loss: 0.277  test_accuracy: 0.678
[7,   400] train_loss: 0.282  test_accuracy: 0.675
[7,   600] train_loss: 0.285  test_accuracy: 0.668
[7,   800] train_loss: 0.296  test_accuracy: 0.676
[7,  1000] train_loss: 0.302  test_accuracy: 0.679
[7,  1200] train_loss: 0.299  test_accuracy: 0.684
[8,   200] train_loss: 0.258  test_accuracy: 0.671
[8,   400] train_loss: 0.265  test_accuracy: 0.679
[8,   600] train_loss: 0.265  test_accuracy: 0.675
[8,   800] train_loss: 0.274  test_accuracy: 0.690
[8,  1000] train_loss: 0.268  test_accuracy: 0.683
[8,  1200] train_loss: 0.275  test_accuracy: 0.693
[9,   200] train_loss: 0.241  test_accuracy: 0.693
[9,   400] train_loss: 0.244  test_accuracy: 0.680
[9,   600] train_loss: 0.261  test_accuracy: 0.682
[9,   800] train_loss: 0.245  test_accuracy: 0.679
[9,  1000] train_loss: 0.256  test_accuracy: 0.683
[9,  1200] train_loss: 0.256  test_accuracy: 0.686
[10,   200] train_loss: 0.222  test_accuracy: 0.679
[10,   400] train_loss: 0.233  test_accuracy: 0.688
[10,   600] train_loss: 0.237  test_accuracy: 0.678
[10,   800] train_loss: 0.244  test_accuracy: 0.689
[10,  1000] train_loss: 0.241  test_accuracy: 0.689
[10,  1200] train_loss: 0.241  test_accuracy: 0.689
[11,   200] train_loss: 0.201  test_accuracy: 0.685
[11,   400] train_loss: 0.218  test_accuracy: 0.684
[11,   600] train_loss: 0.223  test_accuracy: 0.692
[11,   800] train_loss: 0.226  test_accuracy: 0.691
[11,  1000] train_loss: 0.224  test_accuracy: 0.682
[11,  1200] train_loss: 0.229  test_accuracy: 0.689
[12,   200] train_loss: 0.185  test_accuracy: 0.687
[12,   400] train_loss: 0.204  test_accuracy: 0.681
[12,   600] train_loss: 0.203  test_accuracy: 0.689
[12,   800] train_loss: 0.217  test_accuracy: 0.687
[12,  1000] train_loss: 0.217  test_accuracy: 0.685
[12,  1200] train_loss: 0.220  test_accuracy: 0.689
[13,   200] train_loss: 0.181  test_accuracy: 0.684
[13,   400] train_loss: 0.191  test_accuracy: 0.683
[13,   600] train_loss: 0.200  test_accuracy: 0.678
[13,   800] train_loss: 0.201  test_accuracy: 0.679
[13,  1000] train_loss: 0.198  test_accuracy: 0.681
[13,  1200] train_loss: 0.201  test_accuracy: 0.686
[14,   200] train_loss: 0.167  test_accuracy: 0.687
[14,   400] train_loss: 0.178  test_accuracy: 0.683
[14,   600] train_loss: 0.178  test_accuracy: 0.681
[14,   800] train_loss: 0.191  test_accuracy: 0.678
[14,  1000] train_loss: 0.187  test_accuracy: 0.682
[14,  1200] train_loss: 0.196  test_accuracy: 0.687
[15,   200] train_loss: 0.156  test_accuracy: 0.688
[15,   400] train_loss: 0.160  test_accuracy: 0.684
[15,   600] train_loss: 0.169  test_accuracy: 0.684
[15,   800] train_loss: 0.185  test_accuracy: 0.678
[15,  1000] train_loss: 0.180  test_accuracy: 0.670
[15,  1200] train_loss: 0.182  test_accuracy: 0.683
[16,   200] train_loss: 0.147  test_accuracy: 0.679
[16,   400] train_loss: 0.159  test_accuracy: 0.683
[16,   600] train_loss: 0.161  test_accuracy: 0.679
[16,   800] train_loss: 0.163  test_accuracy: 0.678
[16,  1000] train_loss: 0.167  test_accuracy: 0.683
[16,  1200] train_loss: 0.186  test_accuracy: 0.678
[17,   200] train_loss: 0.139  test_accuracy: 0.687
[17,   400] train_loss: 0.147  test_accuracy: 0.679
[17,   600] train_loss: 0.157  test_accuracy: 0.679
[17,   800] train_loss: 0.156  test_accuracy: 0.681
[17,  1000] train_loss: 0.161  test_accuracy: 0.675
[17,  1200] train_loss: 0.162  test_accuracy: 0.675
[18,   200] train_loss: 0.123  test_accuracy: 0.678
[18,   400] train_loss: 0.132  test_accuracy: 0.682
[18,   600] train_loss: 0.141  test_accuracy: 0.673
[18,   800] train_loss: 0.142  test_accuracy: 0.667
[18,  1000] train_loss: 0.160  test_accuracy: 0.679
[18,  1200] train_loss: 0.151  test_accuracy: 0.677
[19,   200] train_loss: 0.122  test_accuracy: 0.673
[19,   400] train_loss: 0.128  test_accuracy: 0.680
[19,   600] train_loss: 0.136  test_accuracy: 0.682
[19,   800] train_loss: 0.142  test_accuracy: 0.682
[19,  1000] train_loss: 0.148  test_accuracy: 0.674
[19,  1200] train_loss: 0.146  test_accuracy: 0.679
[20,   200] train_loss: 0.115  test_accuracy: 0.675
[20,   400] train_loss: 0.117  test_accuracy: 0.672
[20,   600] train_loss: 0.128  test_accuracy: 0.670
[20,   800] train_loss: 0.127  test_accuracy: 0.674
[20,  1000] train_loss: 0.136  test_accuracy: 0.671
[20,  1200] train_loss: 0.146  test_accuracy: 0.668
[21,   200] train_loss: 0.105  test_accuracy: 0.673
[21,   400] train_loss: 0.120  test_accuracy: 0.679
[21,   600] train_loss: 0.116  test_accuracy: 0.677
[21,   800] train_loss: 0.127  test_accuracy: 0.677
[21,  1000] train_loss: 0.134  test_accuracy: 0.674
[21,  1200] train_loss: 0.131  test_accuracy: 0.669
[22,   200] train_loss: 0.096  test_accuracy: 0.665
[22,   400] train_loss: 0.106  test_accuracy: 0.675
[22,   600] train_loss: 0.117  test_accuracy: 0.671
[22,   800] train_loss: 0.117  test_accuracy: 0.684
[22,  1000] train_loss: 0.115  test_accuracy: 0.675
[22,  1200] train_loss: 0.130  test_accuracy: 0.677
[23,   200] train_loss: 0.098  test_accuracy: 0.676
[23,   400] train_loss: 0.107  test_accuracy: 0.671
[23,   600] train_loss: 0.107  test_accuracy: 0.673
[23,   800] train_loss: 0.113  test_accuracy: 0.670
[23,  1000] train_loss: 0.116  test_accuracy: 0.670
[23,  1200] train_loss: 0.121  test_accuracy: 0.671
[24,   200] train_loss: 0.082  test_accuracy: 0.672
[24,   400] train_loss: 0.096  test_accuracy: 0.674
[24,   600] train_loss: 0.110  test_accuracy: 0.669
[24,   800] train_loss: 0.115  test_accuracy: 0.675
[24,  1000] train_loss: 0.107  test_accuracy: 0.671
[24,  1200] train_loss: 0.121  test_accuracy: 0.673
[25,   200] train_loss: 0.082  test_accuracy: 0.674
[25,   400] train_loss: 0.092  test_accuracy: 0.673
[25,   600] train_loss: 0.099  test_accuracy: 0.673
[25,   800] train_loss: 0.105  test_accuracy: 0.672
[25,  1000] train_loss: 0.101  test_accuracy: 0.673
[25,  1200] train_loss: 0.108  test_accuracy: 0.674
[26,   200] train_loss: 0.085  test_accuracy: 0.670
[26,   400] train_loss: 0.087  test_accuracy: 0.664
[26,   600] train_loss: 0.100  test_accuracy: 0.658
[26,   800] train_loss: 0.095  test_accuracy: 0.670
[26,  1000] train_loss: 0.101  test_accuracy: 0.667
[26,  1200] train_loss: 0.110  test_accuracy: 0.663
[27,   200] train_loss: 0.076  test_accuracy: 0.671
[27,   400] train_loss: 0.081  test_accuracy: 0.671
[27,   600] train_loss: 0.091  test_accuracy: 0.662
[27,   800] train_loss: 0.093  test_accuracy: 0.663
[27,  1000] train_loss: 0.093  test_accuracy: 0.670
[27,  1200] train_loss: 0.098  test_accuracy: 0.663
[28,   200] train_loss: 0.080  test_accuracy: 0.670
[28,   400] train_loss: 0.076  test_accuracy: 0.664
[28,   600] train_loss: 0.086  test_accuracy: 0.664
[28,   800] train_loss: 0.099  test_accuracy: 0.672
[28,  1000] train_loss: 0.093  test_accuracy: 0.672
[28,  1200] train_loss: 0.098  test_accuracy: 0.660
[29,   200] train_loss: 0.068  test_accuracy: 0.667
[29,   400] train_loss: 0.072  test_accuracy: 0.668
[29,   600] train_loss: 0.093  test_accuracy: 0.672
[29,   800] train_loss: 0.086  test_accuracy: 0.670
[29,  1000] train_loss: 0.097  test_accuracy: 0.665
[29,  1200] train_loss: 0.091  test_accuracy: 0.658
[30,   200] train_loss: 0.067  test_accuracy: 0.675
[30,   400] train_loss: 0.072  test_accuracy: 0.667
[30,   600] train_loss: 0.080  test_accuracy: 0.665
[30,   800] train_loss: 0.087  test_accuracy: 0.664
[30,  1000] train_loss: 0.091  test_accuracy: 0.668
[30,  1200] train_loss: 0.094  test_accuracy: 0.664
[31,   200] train_loss: 0.073  test_accuracy: 0.669
[31,   400] train_loss: 0.065  test_accuracy: 0.665
[31,   600] train_loss: 0.077  test_accuracy: 0.669
[31,   800] train_loss: 0.074  test_accuracy: 0.665
[31,  1000] train_loss: 0.090  test_accuracy: 0.657
[31,  1200] train_loss: 0.094  test_accuracy: 0.654
[32,   200] train_loss: 0.061  test_accuracy: 0.670
[32,   400] train_loss: 0.075  test_accuracy: 0.664
[32,   600] train_loss: 0.073  test_accuracy: 0.666
[32,   800] train_loss: 0.086  test_accuracy: 0.663
[32,  1000] train_loss: 0.097  test_accuracy: 0.659
[32,  1200] train_loss: 0.085  test_accuracy: 0.662
[33,   200] train_loss: 0.054  test_accuracy: 0.668
[33,   400] train_loss: 0.064  test_accuracy: 0.663
[33,   600] train_loss: 0.059  test_accuracy: 0.661
[33,   800] train_loss: 0.076  test_accuracy: 0.672
[33,  1000] train_loss: 0.086  test_accuracy: 0.663
[33,  1200] train_loss: 0.084  test_accuracy: 0.672
[34,   200] train_loss: 0.056  test_accuracy: 0.663
[34,   400] train_loss: 0.065  test_accuracy: 0.666
[34,   600] train_loss: 0.070  test_accuracy: 0.668
[34,   800] train_loss: 0.082  test_accuracy: 0.663
[34,  1000] train_loss: 0.083  test_accuracy: 0.664
[34,  1200] train_loss: 0.082  test_accuracy: 0.668
[35,   200] train_loss: 0.053  test_accuracy: 0.668
[35,   400] train_loss: 0.053  test_accuracy: 0.664
[35,   600] train_loss: 0.065  test_accuracy: 0.668
[35,   800] train_loss: 0.087  test_accuracy: 0.662
[35,  1000] train_loss: 0.084  test_accuracy: 0.664
[35,  1200] train_loss: 0.085  test_accuracy: 0.660
[36,   200] train_loss: 0.058  test_accuracy: 0.666
[36,   400] train_loss: 0.062  test_accuracy: 0.664
[36,   600] train_loss: 0.062  test_accuracy: 0.669
[36,   800] train_loss: 0.066  test_accuracy: 0.666
[36,  1000] train_loss: 0.076  test_accuracy: 0.656
[36,  1200] train_loss: 0.081  test_accuracy: 0.661
[37,   200] train_loss: 0.052  test_accuracy: 0.668
[37,   400] train_loss: 0.054  test_accuracy: 0.663
[37,   600] train_loss: 0.073  test_accuracy: 0.664
[37,   800] train_loss: 0.071  test_accuracy: 0.660
[37,  1000] train_loss: 0.085  test_accuracy: 0.665
[37,  1200] train_loss: 0.072  test_accuracy: 0.663
[38,   200] train_loss: 0.051  test_accuracy: 0.662
[38,   400] train_loss: 0.065  test_accuracy: 0.657
[38,   600] train_loss: 0.068  test_accuracy: 0.662
[38,   800] train_loss: 0.076  test_accuracy: 0.651
[38,  1000] train_loss: 0.078  test_accuracy: 0.661
[38,  1200] train_loss: 0.073  test_accuracy: 0.661
[39,   200] train_loss: 0.047  test_accuracy: 0.665
[39,   400] train_loss: 0.045  test_accuracy: 0.665
[39,   600] train_loss: 0.056  test_accuracy: 0.657
[39,   800] train_loss: 0.071  test_accuracy: 0.660
[39,  1000] train_loss: 0.075  test_accuracy: 0.664
[39,  1200] train_loss: 0.081  test_accuracy: 0.657
[40,   200] train_loss: 0.050  test_accuracy: 0.668
[40,   400] train_loss: 0.045  test_accuracy: 0.665
[40,   600] train_loss: 0.057  test_accuracy: 0.666
[40,   800] train_loss: 0.066  test_accuracy: 0.655
[40,  1000] train_loss: 0.070  test_accuracy: 0.665
[40,  1200] train_loss: 0.079  test_accuracy: 0.658
[41,   200] train_loss: 0.058  test_accuracy: 0.656
[41,   400] train_loss: 0.054  test_accuracy: 0.658
[41,   600] train_loss: 0.054  test_accuracy: 0.664
[41,   800] train_loss: 0.063  test_accuracy: 0.660
[41,  1000] train_loss: 0.060  test_accuracy: 0.659
[41,  1200] train_loss: 0.075  test_accuracy: 0.650
[42,   200] train_loss: 0.051  test_accuracy: 0.663
[42,   400] train_loss: 0.062  test_accuracy: 0.665
[42,   600] train_loss: 0.050  test_accuracy: 0.656
[42,   800] train_loss: 0.065  test_accuracy: 0.660
[42,  1000] train_loss: 0.058  test_accuracy: 0.658
[42,  1200] train_loss: 0.072  test_accuracy: 0.659
[43,   200] train_loss: 0.052  test_accuracy: 0.666
[43,   400] train_loss: 0.044  test_accuracy: 0.667
[43,   600] train_loss: 0.060  test_accuracy: 0.660
[43,   800] train_loss: 0.063  test_accuracy: 0.661
[43,  1000] train_loss: 0.061  test_accuracy: 0.661
[43,  1200] train_loss: 0.069  test_accuracy: 0.659
[44,   200] train_loss: 0.041  test_accuracy: 0.665
[44,   400] train_loss: 0.052  test_accuracy: 0.657
[44,   600] train_loss: 0.050  test_accuracy: 0.666
[44,   800] train_loss: 0.047  test_accuracy: 0.665
[44,  1000] train_loss: 0.080  test_accuracy: 0.652
[44,  1200] train_loss: 0.075  test_accuracy: 0.652
[45,   200] train_loss: 0.046  test_accuracy: 0.660
[45,   400] train_loss: 0.051  test_accuracy: 0.665
[45,   600] train_loss: 0.055  test_accuracy: 0.666
[45,   800] train_loss: 0.055  test_accuracy: 0.664
[45,  1000] train_loss: 0.065  test_accuracy: 0.657
[45,  1200] train_loss: 0.068  test_accuracy: 0.661
[46,   200] train_loss: 0.044  test_accuracy: 0.655
[46,   400] train_loss: 0.043  test_accuracy: 0.668
[46,   600] train_loss: 0.047  test_accuracy: 0.663
[46,   800] train_loss: 0.045  test_accuracy: 0.664
[46,  1000] train_loss: 0.069  test_accuracy: 0.653
[46,  1200] train_loss: 0.074  test_accuracy: 0.653
[47,   200] train_loss: 0.049  test_accuracy: 0.656
[47,   400] train_loss: 0.045  test_accuracy: 0.658
[47,   600] train_loss: 0.058  test_accuracy: 0.655
[47,   800] train_loss: 0.057  test_accuracy: 0.657
[47,  1000] train_loss: 0.056  test_accuracy: 0.663
[47,  1200] train_loss: 0.058  test_accuracy: 0.662
[48,   200] train_loss: 0.043  test_accuracy: 0.661
[48,   400] train_loss: 0.053  test_accuracy: 0.655
[48,   600] train_loss: 0.062  test_accuracy: 0.654
[48,   800] train_loss: 0.063  test_accuracy: 0.655
[48,  1000] train_loss: 0.062  test_accuracy: 0.661
[48,  1200] train_loss: 0.062  test_accuracy: 0.661
[49,   200] train_loss: 0.049  test_accuracy: 0.659
[49,   400] train_loss: 0.046  test_accuracy: 0.663
[49,   600] train_loss: 0.045  test_accuracy: 0.657
[49,   800] train_loss: 0.043  test_accuracy: 0.666
[49,  1000] train_loss: 0.057  test_accuracy: 0.655
[49,  1200] train_loss: 0.070  test_accuracy: 0.660
[50,   200] train_loss: 0.035  test_accuracy: 0.656
[50,   400] train_loss: 0.038  test_accuracy: 0.661
[50,   600] train_loss: 0.043  test_accuracy: 0.663
[50,   800] train_loss: 0.047  test_accuracy: 0.659
[50,  1000] train_loss: 0.064  test_accuracy: 0.661
[50,  1200] train_loss: 0.071  test_accuracy: 0.651
Finished Training!

3.1 测试¶

In [1]:
import torch
from torchvision.transforms import transforms
from PIL import Image

from model import LeNet

transform = transforms.Compose([transforms.Resize((32,32)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
net = LeNet()
net.load_state_dict(torch.load('LeNet.pth'))

img = Image.open(r'C:/Users/lds/Desktop/test.webp')
# [C, H, W]
img = transform(img)


# [B, C, H, W]
img = torch.unsqueeze(img, dim=0)

with torch.no_grad():
    outputs = net(img)

predict = torch.max(outputs, dim=1)[1].data.numpy()
title = classes[int(predict)]

import matplotlib.pyplot as plt
import numpy as np

# 去除批次维度
image = np.squeeze(img, 0)
# unnormalize
plt.imshow(np.transpose(image, (1,2,0)) / 2 + 0.5)
# IMG show
plt.yticks([])
plt.xticks([])
plt.xlabel(title)
plt.show()
No description has been provided for this image