# UP:B站-刘二大人
# 原视频链接:11.卷积神经网络(高级篇)_哔哩哔哩_bilibili
'''
GoogleNet(本代码不是标准的Googlenet结构,只是跟随视频搭建的简化网络)
每个Inception结构:
concatenate:把张量沿着通道拼接到一起
要求四条路径出来的结果的W和H一样,3*3和5*5的路径可以通过padding实现,池化层可以人为去设置步长和padding让池化后的大小不变
1*1的卷积核:卷积核大小为1x1的卷积层,目的是为了降维(减小深度),减少模型训练参数,减少计算量
'''
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
batch_size = 64

# MNIST per-pixel statistics: mean 0.1307, std 0.3081.
# FIX: the original used 0.1037, a digit-swap typo of the standard mean.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training set is shuffled every epoch; test set keeps a fixed order.
train_dataset = datasets.MNIST(root="../dataset/mnist", train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root="../dataset/mnist", train=False, download=True, transform=transform)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)
class InceptionA(torch.nn.Module):
    """Simplified Inception block: four parallel branches concatenated on channels.

    Every branch preserves the spatial size (H, W) — the 5x5 and 3x3 convs use
    padding, and the pooling branch uses stride 1 with padding — so the four
    outputs can be concatenated. Total output channels: 24 + 16 + 24 + 24 = 88.
    The 1x1 convolutions exist purely to reduce channel depth (fewer params,
    less compute) before the larger kernels.
    """

    def __init__(self, in_channels):
        super(InceptionA, self).__init__()
        # Branch 1: 3x3 average pool followed by 24 1x1 convolutions.
        self.branch_pool = torch.nn.Conv2d(in_channels, 24, kernel_size=1)
        # Branch 2: 16 1x1 convolutions (pure channel reduction).
        self.branch1x1 = torch.nn.Conv2d(in_channels, 16, kernel_size=1)
        # Branch 3: 1x1 reduction, then a padded 5x5 convolution.
        self.branch5x5_1 = torch.nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch5x5_2 = torch.nn.Conv2d(16, 24, kernel_size=5, padding=2)
        # Branch 4: 1x1 reduction, then two padded 3x3 convolutions.
        self.branch3x3_1 = torch.nn.Conv2d(in_channels, 16, kernel_size=1)
        self.branch3x3_2 = torch.nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch3x3_3 = torch.nn.Conv2d(24, 24, kernel_size=3, padding=1)

    def forward(self, x):
        """Run x through all four branches and join them along dim 1 (channels)."""
        pooled = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branches = [
            self.branch_pool(pooled),
            self.branch1x1(x),
            self.branch5x5_2(self.branch5x5_1(x)),
            self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x))),
        ]
        # Tensors are (N, C, H, W); dim=1 stacks the branch channels together.
        return torch.cat(branches, dim=1)
class Model(torch.nn.Module):
    """Simplified GoogLeNet-style classifier for 1x28x28 MNIST digits.

    Pipeline: conv1 -> ReLU -> maxpool -> InceptionA -> conv2 -> ReLU ->
    maxpool -> InceptionA -> flatten -> four fully connected layers -> 10 logits.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        # Each InceptionA emits 24 + 16 + 24 + 24 = 88 channels, hence conv2's input.
        self.conv2 = torch.nn.Conv2d(88, 20, kernel_size=5)
        self.incep1 = InceptionA(in_channels=10)
        self.incep2 = InceptionA(in_channels=20)
        self.maxpooling = torch.nn.MaxPool2d(2)
        self.active = torch.nn.ReLU()
        # Shape trace: 1x28x28 -> conv1 -> 10x24x24 -> pool -> 10x12x12
        #   -> incep1 -> 88x12x12 -> conv2 -> 20x8x8 -> pool -> 20x4x4
        #   -> incep2 -> 88x4x4
        self.fc1 = torch.nn.Linear(1408, 512)  # 88 * 4 * 4 = 1408
        self.fc2 = torch.nn.Linear(512, 128)
        self.fc3 = torch.nn.Linear(128, 32)
        self.fc4 = torch.nn.Linear(32, 10)

    def forward(self, x):
        x = self.maxpooling(self.active(self.conv1(x)))
        x = self.incep1(x)
        x = self.maxpooling(self.active(self.conv2(x)))
        x = self.incep2(x)
        # x is (batch_size, 88, 4, 4); flatten everything but the batch dim.
        x = x.view(x.size(0), -1)
        # FIX: the original chained fc1..fc4 with no nonlinearity between them,
        # which makes four Linear layers collapse into a single affine map.
        # Insert ReLU so the extra layers actually add capacity. Raw logits are
        # returned (no softmax), as expected by CrossEntropyLoss.
        x = self.active(self.fc1(x))
        x = self.active(self.fc2(x))
        x = self.active(self.fc3(x))
        return self.fc4(x)
model = Model()
# CrossEntropyLoss combines log-softmax + NLL, so the model must output raw logits.
criterion = torch.nn.CrossEntropyLoss()
# Plain SGD with momentum, as in the lecture.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# Per-batch training losses, appended in train() and plotted at the end.
loss_list = []
def train(epoch):
    """Run one training epoch over train_loader.

    epoch is 0-based and used only for logging (printed as epoch + 1).
    Appends every batch loss to the module-level loss_list, and prints the
    running mean loss of the epoch so far every 300 batches.
    """
    loss_sum = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss_sum += loss.item()
        loss_list.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 300 == 299:
            # FIX: loss_sum holds i + 1 terms at this point; the original
            # divided by i, an off-by-one that inflated the reported mean.
            print("[%d %5d] loss:%.3f" % (epoch + 1, i + 1, loss_sum / (i + 1)))
# (An earlier 5-epoch run was removed; the __main__ guard below runs 10 epochs.)
# Per-epoch test accuracies (percentages), appended by test().
accuracy = []
def test():
    """Evaluate the model on the whole test set; print and record accuracy (%)."""
    total = 0
    correct = 0
    # FIX: evaluation needs no gradients — no_grad() skips autograd graph
    # construction, saving memory and time. (Also drops the deprecated .data.)
    with torch.no_grad():
        for (images, targets) in test_loader:
            y_pred = model(images)
            # Predicted class = index of the max logit along the class dimension.
            _, predicted = torch.max(y_pred, dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print("Accuracy on test data:%.3f %% [%d %5d]" % (100 * correct / total, correct, total))
    accuracy.append(100 * correct / total)
if __name__ == "__main__":
    for epoch in range(10):
        train(epoch)
        test()
    # Left panel: per-batch training loss; right panel: per-epoch test accuracy.
    plt.subplot(121)
    plt.plot(range(len(loss_list)), loss_list)
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.subplot(122)
    # FIX: index by len(accuracy) instead of the leaked loop variable `epoch`,
    # which is fragile (undefined if the loop never runs) and easy to desync.
    plt.plot(range(len(accuracy)), accuracy)
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.show()